Identify parent_flow of Crew and LiteAgent (#2723)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit adds a new crew field called parent_flow, evaluated when the Crew
instance is instantiated. The stacktrace is traversed to look up if the caller
is an instance of Flow, and if so, it fills in the field.

Other alternatives were considered, such as a global context or even a new
field to be manually filled, however, this is the most magical solution that
was thread-safe and did not require public API changes.
This commit is contained in:
Vini Brasil
2025-05-02 14:40:39 -03:00
committed by GitHub
parent f89c2bfb7e
commit 17474a3a0c
6 changed files with 201 additions and 72 deletions

View File

@@ -6,7 +6,17 @@ import warnings
from concurrent.futures import Future from concurrent.futures import Future
from copy import copy as shallow_copy from copy import copy as shallow_copy
from hashlib import md5 from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from pydantic import ( from pydantic import (
UUID4, UUID4,
@@ -24,6 +34,7 @@ from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache import CacheHandler from crewai.agents.cache import CacheHandler
from crewai.crews.crew_output import CrewOutput from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM, BaseLLM from crewai.llm import LLM, BaseLLM
@@ -69,7 +80,7 @@ from crewai.utilities.training_handler import CrewTrainingHandler
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd") warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
class Crew(BaseModel): class Crew(FlowTrackable, BaseModel):
""" """
Represents a group of agents, defining how they should collaborate and the tasks they should perform. Represents a group of agents, defining how they should collaborate and the tasks they should perform.

View File

@@ -0,0 +1,44 @@
import inspect
from typing import Optional
from pydantic import BaseModel, Field, InstanceOf, model_validator
from crewai.flow import Flow
class FlowTrackable(BaseModel):
"""Mixin that tracks the Flow instance that instantiated the object, e.g. a
Flow instance that created a Crew or Agent.
Automatically finds and stores a reference to the parent Flow instance by
inspecting the call stack.
"""
parent_flow: Optional[InstanceOf[Flow]] = Field(
default=None,
description="The parent flow of the instance, if it was created inside a flow.",
)
@model_validator(mode="after")
def _set_parent_flow(self, max_depth: int = 5) -> "FlowTrackable":
frame = inspect.currentframe()
try:
if frame is None:
return self
frame = frame.f_back
for _ in range(max_depth):
if frame is None:
break
candidate = frame.f_locals.get("self")
if isinstance(candidate, Flow):
self.parent_flow = candidate
break
frame = frame.f_back
finally:
del frame
return self

View File

@@ -13,6 +13,7 @@ from crewai.agents.parser import (
AgentFinish, AgentFinish,
OutputParserException, OutputParserException,
) )
from crewai.flow.flow_trackable import FlowTrackable
from crewai.llm import LLM from crewai.llm import LLM
from crewai.tools.base_tool import BaseTool from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.structured_tool import CrewStructuredTool
@@ -80,7 +81,7 @@ class LiteAgentOutput(BaseModel):
return self.raw return self.raw
class LiteAgent(BaseModel): class LiteAgent(FlowTrackable, BaseModel):
""" """
A lightweight agent that can process messages and use tools. A lightweight agent that can process messages and use tools.

View File

@@ -3,12 +3,13 @@ import tempfile
import unittest import unittest
import unittest.mock import unittest.mock
from contextlib import contextmanager from contextlib import contextmanager
from io import StringIO
from unittest import mock from unittest import mock
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from pytest import raises from pytest import raises
from crewai.cli.authentication.utils import TokenManager
from crewai.cli.tools.main import ToolCommand from crewai.cli.tools.main import ToolCommand
@@ -23,17 +24,20 @@ def in_temp_dir():
os.chdir(original_dir) os.chdir(original_dir)
@patch("crewai.cli.tools.main.subprocess.run") @pytest.fixture
def test_create_success(mock_subprocess): def tool_command():
with in_temp_dir(): TokenManager().save_tokens("test-token", 36000)
tool_command = ToolCommand() tool_command = ToolCommand()
with patch.object(tool_command, "login"):
yield tool_command
with (
patch.object(tool_command, "login") as mock_login, @patch("crewai.cli.tools.main.subprocess.run")
patch("sys.stdout", new=StringIO()) as fake_out, def test_create_success(mock_subprocess, capsys, tool_command):
): with in_temp_dir():
tool_command.create("test-tool") tool_command.create("test-tool")
output = fake_out.getvalue() output = capsys.readouterr().out
assert "Creating custom tool test_tool..." in output
assert os.path.isdir("test_tool") assert os.path.isdir("test_tool")
assert os.path.isfile(os.path.join("test_tool", "README.md")) assert os.path.isfile(os.path.join("test_tool", "README.md"))
@@ -47,15 +51,12 @@ def test_create_success(mock_subprocess):
content = f.read() content = f.read()
assert "class TestTool" in content assert "class TestTool" in content
mock_login.assert_called_once()
mock_subprocess.assert_called_once_with(["git", "init"], check=True) mock_subprocess.assert_called_once_with(["git", "init"], check=True)
assert "Creating custom tool test_tool..." in output
@patch("crewai.cli.tools.main.subprocess.run") @patch("crewai.cli.tools.main.subprocess.run")
@patch("crewai.cli.plus_api.PlusAPI.get_tool") @patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_success(mock_get, mock_subprocess_run): def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
mock_get_response = MagicMock() mock_get_response = MagicMock()
mock_get_response.status_code = 200 mock_get_response.status_code = 200
mock_get_response.json.return_value = { mock_get_response.json.return_value = {
@@ -65,11 +66,9 @@ def test_install_success(mock_get, mock_subprocess_run):
mock_get.return_value = mock_get_response mock_get.return_value = mock_get_response
mock_subprocess_run.return_value = MagicMock(stderr=None) mock_subprocess_run.return_value = MagicMock(stderr=None)
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
tool_command.install("sample-tool") tool_command.install("sample-tool")
output = fake_out.getvalue() output = capsys.readouterr().out
assert "Successfully installed sample-tool" in output
mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()]) mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()])
mock_subprocess_run.assert_any_call( mock_subprocess_run.assert_any_call(
@@ -86,54 +85,42 @@ def test_install_success(mock_get, mock_subprocess_run):
env=unittest.mock.ANY, env=unittest.mock.ANY,
) )
assert "Successfully installed sample-tool" in output
@patch("crewai.cli.plus_api.PlusAPI.get_tool") @patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_tool_not_found(mock_get): def test_install_tool_not_found(mock_get, capsys, tool_command):
mock_get_response = MagicMock() mock_get_response = MagicMock()
mock_get_response.status_code = 404 mock_get_response.status_code = 404
mock_get.return_value = mock_get_response mock_get.return_value = mock_get_response
tool_command = ToolCommand() with raises(SystemExit):
with patch("sys.stdout", new=StringIO()) as fake_out:
try:
tool_command.install("non-existent-tool") tool_command.install("non-existent-tool")
except SystemExit: output = capsys.readouterr().out
pass assert "No tool found with this name" in output
output = fake_out.getvalue()
mock_get.assert_called_once_with("non-existent-tool") mock_get.assert_called_once_with("non-existent-tool")
assert "No tool found with this name" in output
@patch("crewai.cli.plus_api.PlusAPI.get_tool") @patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_api_error(mock_get): def test_install_api_error(mock_get, capsys, tool_command):
mock_get_response = MagicMock() mock_get_response = MagicMock()
mock_get_response.status_code = 500 mock_get_response.status_code = 500
mock_get.return_value = mock_get_response mock_get.return_value = mock_get_response
tool_command = ToolCommand() with raises(SystemExit):
with patch("sys.stdout", new=StringIO()) as fake_out:
try:
tool_command.install("error-tool") tool_command.install("error-tool")
except SystemExit: output = capsys.readouterr().out
pass assert "Failed to get tool details" in output
output = fake_out.getvalue()
mock_get.assert_called_once_with("error-tool") mock_get.assert_called_once_with("error-tool")
assert "Failed to get tool details" in output
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False) @patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
def test_publish_when_not_in_sync(mock_is_synced): def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
with patch("sys.stdout", new=StringIO()) as fake_out, raises(SystemExit): with raises(SystemExit):
tool_command = ToolCommand()
tool_command.publish(is_public=True) tool_command.publish(is_public=True)
assert "Local changes need to be resolved before publishing" in fake_out.getvalue() output = capsys.readouterr().out
assert "Local changes need to be resolved before publishing" in output
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@@ -157,13 +144,13 @@ def test_publish_when_not_in_sync_and_force(
mock_get_project_description, mock_get_project_description,
mock_get_project_version, mock_get_project_version,
mock_get_project_name, mock_get_project_name,
tool_command,
): ):
mock_publish_response = MagicMock() mock_publish_response = MagicMock()
mock_publish_response.status_code = 200 mock_publish_response.status_code = 200
mock_publish_response.json.return_value = {"handle": "sample-tool"} mock_publish_response.json.return_value = {"handle": "sample-tool"}
mock_publish.return_value = mock_publish_response mock_publish.return_value = mock_publish_response
tool_command = ToolCommand()
tool_command.publish(is_public=True, force=True) tool_command.publish(is_public=True, force=True)
mock_get_project_name.assert_called_with(require=True) mock_get_project_name.assert_called_with(require=True)
@@ -205,13 +192,13 @@ def test_publish_success(
mock_get_project_description, mock_get_project_description,
mock_get_project_version, mock_get_project_version,
mock_get_project_name, mock_get_project_name,
tool_command,
): ):
mock_publish_response = MagicMock() mock_publish_response = MagicMock()
mock_publish_response.status_code = 200 mock_publish_response.status_code = 200
mock_publish_response.json.return_value = {"handle": "sample-tool"} mock_publish_response.json.return_value = {"handle": "sample-tool"}
mock_publish.return_value = mock_publish_response mock_publish.return_value = mock_publish_response
tool_command = ToolCommand()
tool_command.publish(is_public=True) tool_command.publish(is_public=True)
mock_get_project_name.assert_called_with(require=True) mock_get_project_name.assert_called_with(require=True)
@@ -251,25 +238,22 @@ def test_publish_failure(
mock_get_project_description, mock_get_project_description,
mock_get_project_version, mock_get_project_version,
mock_get_project_name, mock_get_project_name,
capsys,
tool_command,
): ):
mock_publish_response = MagicMock() mock_publish_response = MagicMock()
mock_publish_response.status_code = 422 mock_publish_response.status_code = 422
mock_publish_response.json.return_value = {"name": ["is already taken"]} mock_publish_response.json.return_value = {"name": ["is already taken"]}
mock_publish.return_value = mock_publish_response mock_publish.return_value = mock_publish_response
tool_command = ToolCommand() with raises(SystemExit):
with patch("sys.stdout", new=StringIO()) as fake_out:
try:
tool_command.publish(is_public=True) tool_command.publish(is_public=True)
except SystemExit: output = capsys.readouterr().out
pass
output = fake_out.getvalue()
mock_publish.assert_called_once()
assert "Failed to complete operation" in output assert "Failed to complete operation" in output
assert "Name is already taken" in output assert "Name is already taken" in output
mock_publish.assert_called_once()
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0") @patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@@ -290,6 +274,8 @@ def test_publish_api_error(
mock_get_project_description, mock_get_project_description,
mock_get_project_version, mock_get_project_version,
mock_get_project_name, mock_get_project_name,
capsys,
tool_command,
): ):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 500 mock_response.status_code = 500
@@ -297,14 +283,9 @@ def test_publish_api_error(
mock_response.ok = False mock_response.ok = False
mock_publish.return_value = mock_response mock_publish.return_value = mock_response
tool_command = ToolCommand() with raises(SystemExit):
with patch("sys.stdout", new=StringIO()) as fake_out:
try:
tool_command.publish(is_public=True) tool_command.publish(is_public=True)
except SystemExit: output = capsys.readouterr().out
pass assert "Request to Enterprise API failed" in output
output = fake_out.getvalue()
mock_publish.assert_called_once() mock_publish.assert_called_once()
assert "Request to Enterprise API failed" in output

View File

@@ -17,6 +17,7 @@ from crewai.agents.cache import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.crew import Crew from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput from crewai.crews.crew_output import CrewOutput
from crewai.flow import Flow, listen, start
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.memory.contextual.contextual_memory import ContextualMemory
@@ -2164,7 +2165,6 @@ def test_tools_with_custom_caching():
with patch.object( with patch.object(
CacheHandler, "add", wraps=crew._cache_handler.add CacheHandler, "add", wraps=crew._cache_handler.add
) as add_to_cache: ) as add_to_cache:
result = crew.kickoff() result = crew.kickoff()
# Check that add_to_cache was called exactly twice # Check that add_to_cache was called exactly twice
@@ -4351,3 +4351,35 @@ def test_crew_copy_with_memory():
raise e # Re-raise other validation errors raise e # Re-raise other validation errors
except Exception as e: except Exception as e:
pytest.fail(f"Copying crew raised an unexpected exception: {e}") pytest.fail(f"Copying crew raised an unexpected exception: {e}")
def test_sets_parent_flow_when_outside_flow(researcher, writer):
crew = Crew(
agents=[researcher, writer],
process=Process.sequential,
tasks=[
Task(description="Task 1", expected_output="output", agent=researcher),
Task(description="Task 2", expected_output="output", agent=writer),
],
)
assert crew.parent_flow is None
def test_sets_parent_flow_when_inside_flow(researcher, writer):
class MyFlow(Flow):
@start()
def start(self):
return Crew(
agents=[researcher, writer],
process=Process.sequential,
tasks=[
Task(
description="Task 1", expected_output="output", agent=researcher
),
Task(description="Task 2", expected_output="output", agent=writer),
],
)
flow = MyFlow()
result = flow.kickoff()
assert result.parent_flow is flow

View File

@@ -1,13 +1,16 @@
import asyncio import asyncio
from typing import cast from typing import cast
from unittest.mock import Mock
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai import LLM, Agent from crewai import LLM, Agent
from crewai.flow import Flow, start
from crewai.lite_agent import LiteAgent, LiteAgentOutput from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.tools import BaseTool from crewai.tools import BaseTool
from crewai.utilities.events import crewai_event_bus from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.agent_events import LiteAgentExecutionStartedEvent
from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent
@@ -255,3 +258,60 @@ async def test_lite_agent_returns_usage_metrics_async():
assert "21 million" in result.raw or "37 million" in result.raw assert "21 million" in result.raw or "37 million" in result.raw
assert result.usage_metrics is not None assert result.usage_metrics is not None
assert result.usage_metrics["total_tokens"] > 0 assert result.usage_metrics["total_tokens"] > 0
class TestFlow(Flow):
"""A test flow that creates and runs an agent."""
def __init__(self, llm, tools):
self.llm = llm
self.tools = tools
super().__init__()
@start()
def start(self):
agent = Agent(
role="Test Agent",
goal="Test Goal",
backstory="Test Backstory",
llm=self.llm,
tools=self.tools,
)
return agent.kickoff("Test query")
def verify_agent_parent_flow(result, agent, flow):
"""Verify that both the result and agent have the correct parent flow."""
assert result.parent_flow is flow
assert agent is not None
assert agent.parent_flow is flow
def test_sets_parent_flow_when_inside_flow():
captured_agent = None
mock_llm = Mock(spec=LLM)
mock_llm.call.return_value = "Test response"
class MyFlow(Flow):
@start()
def start(self):
agent = Agent(
role="Test Agent",
goal="Test Goal",
backstory="Test Backstory",
llm=mock_llm,
tools=[WebSearchTool()],
)
return agent.kickoff("Test query")
flow = MyFlow()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
def capture_agent(source, event):
nonlocal captured_agent
captured_agent = source
result = flow.kickoff()
assert captured_agent.parent_flow is flow