From 6464e845499697e956c47b963f86fe8ceb90ce32 Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Sat, 10 May 2025 08:27:01 -0300 Subject: [PATCH] feat: support to load an Agent from a repository --- src/crewai/agent.py | 44 +++++++++++++ src/crewai/cli/authentication/token.py | 2 +- src/crewai/cli/plus_api.py | 4 ++ src/crewai/utilities/errors.py | 7 ++ tests/agent_test.py | 90 ++++++++++++++++++++++++-- 5 files changed, 141 insertions(+), 6 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index dc637967f..9a16b7047 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -25,6 +25,7 @@ from crewai.utilities.agent_utils import ( ) from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE from crewai.utilities.converter import generate_model_description +from crewai.utilities.errors import AgentRepositoryError from crewai.utilities.events.agent_events import ( AgentExecutionCompletedEvent, AgentExecutionErrorEvent, @@ -134,6 +135,49 @@ class Agent(BaseAgent): default=None, description="Knowledge search query for the agent dynamically generated by the agent.", ) + from_repository: Optional[str] = Field( + default=None, + description="The Agent's role to be used from your repository.", + ) + + @model_validator(mode="before") + def validate_from_repository(cls, v): + if v is not None and (from_repository := v.get("from_repository")): + return cls._load_agent_from_repository(from_repository) | v + return v + + @classmethod + def _load_agent_from_repository(cls, from_repository: str) -> Dict[str, Any]: + attributes: Dict[str, Any] = {} + if from_repository: + import importlib + + from crewai.cli.authentication.token import get_auth_token + from crewai.cli.plus_api import PlusAPI + + client = PlusAPI(api_key=get_auth_token()) + response = client.get_agent(from_repository) + if response.status_code != 200: + raise AgentRepositoryError( + f"Agent {from_repository} could not be loaded: {response.text}" + ) + + agent = response.json() + for key, value in agent.items(): + if key == "tools": + attributes[key] = [] + for tool_name in value: + try: + module = importlib.import_module("crewai_tools") + tool_class = getattr(module, tool_name) + attributes[key].append(tool_class()) + except Exception as e: + raise AgentRepositoryError( + f"Tool {tool_name} could not be loaded: {e}" + ) from e + else: + attributes[key] = value + return attributes @model_validator(mode="after") def post_init_setup(self): diff --git a/src/crewai/cli/authentication/token.py b/src/crewai/cli/authentication/token.py index 30a33b4ba..46ff75e89 100644 --- a/src/crewai/cli/authentication/token.py +++ b/src/crewai/cli/authentication/token.py @@ -5,5 +5,5 @@ def get_auth_token() -> str: """Get the authentication token.""" access_token = TokenManager().get_token() if not access_token: - raise Exception() + raise Exception("No token found, make sure you are logged in") return access_token diff --git a/src/crewai/cli/plus_api.py b/src/crewai/cli/plus_api.py index 23032ca8f..93e5750c8 100644 --- a/src/crewai/cli/plus_api.py +++ b/src/crewai/cli/plus_api.py @@ -14,6 +14,7 @@ class PlusAPI: TOOLS_RESOURCE = "/crewai_plus/api/v1/tools" CREWS_RESOURCE = "/crewai_plus/api/v1/crews" + AGENTS_RESOURCE = "/crewai_plus/api/v1/agents" def __init__(self, api_key: str) -> None: self.api_key = api_key @@ -37,6 +38,9 @@ class PlusAPI: def get_tool(self, handle: str): return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}") + def get_agent(self, handle: str): + return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}") + def publish_tool( self, handle: str, diff --git a/src/crewai/utilities/errors.py b/src/crewai/utilities/errors.py index f673c0600..16c59321e 100644 --- a/src/crewai/utilities/errors.py +++ b/src/crewai/utilities/errors.py @@ -1,4 +1,5 @@ """Error message definitions for CrewAI database operations.""" + from typing import Optional @@ -37,3 +38,9 @@ class DatabaseError: The formatted error message """ return template.format(str(error)) + + +class AgentRepositoryError(Exception): + """Exception raised when an agent repository is not found.""" + + ... diff --git a/tests/agent_test.py b/tests/agent_test.py index faad4ca84..968761ce9 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -2,7 +2,7 @@ import os from unittest import mock -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -18,6 +18,7 @@ from crewai.tools import tool from crewai.tools.tool_calling import InstructorToolCalling from crewai.tools.tool_usage import ToolUsage from crewai.utilities import RPMController +from crewai.utilities.errors import AgentRepositoryError from crewai.utilities.events import crewai_event_bus from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent @@ -308,9 +309,7 @@ def test_cache_hitting(): def handle_tool_end(source, event): received_events.append(event) - with ( - patch.object(CacheHandler, "read") as read, - ): + with (patch.object(CacheHandler, "read") as read,): read.return_value = "0" task = Task( description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.", @@ -1040,7 +1039,7 @@ def test_agent_human_input(): CrewAgentExecutor, "_invoke_loop", return_value=AgentFinish(output="Hello", thought="", text=""), - ) as mock_invoke_loop, + ), ): # Execute the task output = agent.execute_task(task) @@ -2025,3 +2024,84 @@ def test_get_knowledge_search_query(): }, ] ) + + +@patch("crewai.cli.plus_api.PlusAPI.get_agent") +def test_agent_from_repository(mock_get_agent): + from crewai_tools import SerperDevTool + + mock_get_response = MagicMock() + mock_get_response.status_code = 200 + mock_get_response.json.return_value = { + "role": "test role", + "goal": "test goal", + "backstory": "test backstory", + "tools": ["SerperDevTool"], + } + mock_get_agent.return_value = mock_get_response + agent = Agent(from_repository="test_agent") + + assert agent.role == "test role" + assert agent.goal == "test goal" + assert agent.backstory == "test backstory" + assert len(agent.tools) == 1 + assert isinstance(agent.tools[0], SerperDevTool) + + +@pytest.fixture +def mock_get_auth_token(): + with patch("crewai.cli.command.get_auth_token", return_value="test_token"): + yield + + +@patch("crewai.cli.plus_api.PlusAPI.get_agent") +def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token): + from crewai_tools import SerperDevTool + + mock_get_response = MagicMock() + mock_get_response.status_code = 200 + mock_get_response.json.return_value = { + "role": "test role", + "goal": "test goal", + "backstory": "test backstory", + "tools": ["SerperDevTool"], + } + mock_get_agent.return_value = mock_get_response + agent = Agent(from_repository="test_agent", role="Custom Role") + + assert agent.role == "Custom Role" + assert agent.goal == "test goal" + assert agent.backstory == "test backstory" + assert len(agent.tools) == 1 + assert isinstance(agent.tools[0], SerperDevTool) + + +@patch("crewai.cli.plus_api.PlusAPI.get_agent") +def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_token): + mock_get_response = MagicMock() + mock_get_response.status_code = 200 + mock_get_response.json.return_value = { + "role": "test role", + "goal": "test goal", + "backstory": "test backstory", + "tools": ["DoesNotExist"], + } + mock_get_agent.return_value = mock_get_response + with pytest.raises( + AgentRepositoryError, + match="Tool DoesNotExist could not be loaded: module 'crewai_tools' has no attribute 'DoesNotExist'", + ): + Agent(from_repository="test_agent") + + +@patch("crewai.cli.plus_api.PlusAPI.get_agent") +def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_token): + mock_get_response = MagicMock() + mock_get_response.status_code = 404 + mock_get_response.text = "Agent not found" + mock_get_agent.return_value = mock_get_response + with pytest.raises( + AgentRepositoryError, + match="Agent NOT_FOUND could not be loaded: Agent not found", + ): + Agent(from_repository="NOT_FOUND")