mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-22 14:48:13 +00:00
Merge branch 'main' into lg-allow-remove-stop
This commit is contained in:
@@ -3,12 +3,13 @@ import tempfile
|
||||
import unittest
|
||||
import unittest.mock
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest import raises
|
||||
|
||||
from crewai.cli.authentication.utils import TokenManager
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
|
||||
|
||||
@@ -23,17 +24,20 @@ def in_temp_dir():
|
||||
os.chdir(original_dir)
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
def test_create_success(mock_subprocess):
|
||||
with in_temp_dir():
|
||||
tool_command = ToolCommand()
|
||||
@pytest.fixture
|
||||
def tool_command():
|
||||
TokenManager().save_tokens("test-token", 36000)
|
||||
tool_command = ToolCommand()
|
||||
with patch.object(tool_command, "login"):
|
||||
yield tool_command
|
||||
|
||||
with (
|
||||
patch.object(tool_command, "login") as mock_login,
|
||||
patch("sys.stdout", new=StringIO()) as fake_out,
|
||||
):
|
||||
tool_command.create("test-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
def test_create_success(mock_subprocess, capsys, tool_command):
|
||||
with in_temp_dir():
|
||||
tool_command.create("test-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Creating custom tool test_tool..." in output
|
||||
|
||||
assert os.path.isdir("test_tool")
|
||||
assert os.path.isfile(os.path.join("test_tool", "README.md"))
|
||||
@@ -47,15 +51,12 @@ def test_create_success(mock_subprocess):
|
||||
content = f.read()
|
||||
assert "class TestTool" in content
|
||||
|
||||
mock_login.assert_called_once()
|
||||
mock_subprocess.assert_called_once_with(["git", "init"], check=True)
|
||||
|
||||
assert "Creating custom tool test_tool..." in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_success(mock_get, mock_subprocess_run):
|
||||
def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {
|
||||
@@ -65,11 +66,9 @@ def test_install_success(mock_get, mock_subprocess_run):
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
tool_command.install("sample-tool")
|
||||
output = fake_out.getvalue()
|
||||
tool_command.install("sample-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Successfully installed sample-tool" in output
|
||||
|
||||
mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()])
|
||||
mock_subprocess_run.assert_any_call(
|
||||
@@ -86,54 +85,42 @@ def test_install_success(mock_get, mock_subprocess_run):
|
||||
env=unittest.mock.ANY,
|
||||
)
|
||||
|
||||
assert "Successfully installed sample-tool" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_tool_not_found(mock_get):
|
||||
def test_install_tool_not_found(mock_get, capsys, tool_command):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 404
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.install("non-existent-tool")
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
with raises(SystemExit):
|
||||
tool_command.install("non-existent-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "No tool found with this name" in output
|
||||
|
||||
mock_get.assert_called_once_with("non-existent-tool")
|
||||
assert "No tool found with this name" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_api_error(mock_get):
|
||||
def test_install_api_error(mock_get, capsys, tool_command):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 500
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.install("error-tool")
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
with raises(SystemExit):
|
||||
tool_command.install("error-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Failed to get tool details" in output
|
||||
|
||||
mock_get.assert_called_once_with("error-tool")
|
||||
assert "Failed to get tool details" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
|
||||
def test_publish_when_not_in_sync(mock_is_synced):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out, raises(SystemExit):
|
||||
tool_command = ToolCommand()
|
||||
def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
assert "Local changes need to be resolved before publishing" in fake_out.getvalue()
|
||||
output = capsys.readouterr().out
|
||||
assert "Local changes need to be resolved before publishing" in output
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@@ -157,13 +144,13 @@ def test_publish_when_not_in_sync_and_force(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True, force=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
@@ -205,13 +192,13 @@ def test_publish_success(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
@@ -251,25 +238,22 @@ def test_publish_failure(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 422
|
||||
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.publish(is_public=True)
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = capsys.readouterr().out
|
||||
assert "Failed to complete operation" in output
|
||||
assert "Name is already taken" in output
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@@ -290,6 +274,8 @@ def test_publish_api_error(
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@@ -297,14 +283,9 @@ def test_publish_api_error(
|
||||
mock_response.ok = False
|
||||
mock_publish.return_value = mock_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
try:
|
||||
tool_command.publish(is_public=True)
|
||||
except SystemExit:
|
||||
pass
|
||||
output = fake_out.getvalue()
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = capsys.readouterr().out
|
||||
assert "Request to Enterprise API failed" in output
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
assert "Request to Enterprise API failed" in output
|
||||
|
||||
@@ -17,6 +17,7 @@ from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow import Flow, listen, start
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
@@ -2164,7 +2165,6 @@ def test_tools_with_custom_caching():
|
||||
with patch.object(
|
||||
CacheHandler, "add", wraps=crew._cache_handler.add
|
||||
) as add_to_cache:
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
# Check that add_to_cache was called exactly twice
|
||||
@@ -4351,3 +4351,35 @@ def test_crew_copy_with_memory():
|
||||
raise e # Re-raise other validation errors
|
||||
except Exception as e:
|
||||
pytest.fail(f"Copying crew raised an unexpected exception: {e}")
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_outside_flow(researcher, writer):
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
process=Process.sequential,
|
||||
tasks=[
|
||||
Task(description="Task 1", expected_output="output", agent=researcher),
|
||||
Task(description="Task 2", expected_output="output", agent=writer),
|
||||
],
|
||||
)
|
||||
assert crew.parent_flow is None
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_inside_flow(researcher, writer):
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def start(self):
|
||||
return Crew(
|
||||
agents=[researcher, writer],
|
||||
process=Process.sequential,
|
||||
tasks=[
|
||||
Task(
|
||||
description="Task 1", expected_output="output", agent=researcher
|
||||
),
|
||||
Task(description="Task 2", expected_output="output", agent=writer),
|
||||
],
|
||||
)
|
||||
|
||||
flow = MyFlow()
|
||||
result = flow.kickoff()
|
||||
assert result.parent_flow is flow
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import asyncio
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai import LLM, Agent
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.agent_events import LiteAgentExecutionStartedEvent
|
||||
from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent
|
||||
|
||||
|
||||
@@ -255,3 +258,60 @@ async def test_lite_agent_returns_usage_metrics_async():
|
||||
assert "21 million" in result.raw or "37 million" in result.raw
|
||||
assert result.usage_metrics is not None
|
||||
assert result.usage_metrics["total_tokens"] > 0
|
||||
|
||||
|
||||
class TestFlow(Flow):
|
||||
"""A test flow that creates and runs an agent."""
|
||||
|
||||
def __init__(self, llm, tools):
|
||||
self.llm = llm
|
||||
self.tools = tools
|
||||
super().__init__()
|
||||
|
||||
@start()
|
||||
def start(self):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test Goal",
|
||||
backstory="Test Backstory",
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
)
|
||||
return agent.kickoff("Test query")
|
||||
|
||||
|
||||
def verify_agent_parent_flow(result, agent, flow):
|
||||
"""Verify that both the result and agent have the correct parent flow."""
|
||||
assert result.parent_flow is flow
|
||||
assert agent is not None
|
||||
assert agent.parent_flow is flow
|
||||
|
||||
|
||||
def test_sets_parent_flow_when_inside_flow():
|
||||
captured_agent = None
|
||||
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.call.return_value = "Test response"
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def start(self):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test Goal",
|
||||
backstory="Test Backstory",
|
||||
llm=mock_llm,
|
||||
tools=[WebSearchTool()],
|
||||
)
|
||||
return agent.kickoff("Test query")
|
||||
|
||||
flow = MyFlow()
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def capture_agent(source, event):
|
||||
nonlocal captured_agent
|
||||
captured_agent = source
|
||||
|
||||
result = flow.kickoff()
|
||||
assert captured_agent.parent_flow is flow
|
||||
|
||||
Reference in New Issue
Block a user