mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: support to load an Agent from a repository
This commit is contained in:
@@ -25,6 +25,7 @@ from crewai.utilities.agent_utils import (
|
|||||||
)
|
)
|
||||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||||
from crewai.utilities.converter import generate_model_description
|
from crewai.utilities.converter import generate_model_description
|
||||||
|
from crewai.utilities.errors import AgentRepositoryError
|
||||||
from crewai.utilities.events.agent_events import (
|
from crewai.utilities.events.agent_events import (
|
||||||
AgentExecutionCompletedEvent,
|
AgentExecutionCompletedEvent,
|
||||||
AgentExecutionErrorEvent,
|
AgentExecutionErrorEvent,
|
||||||
@@ -134,6 +135,49 @@ class Agent(BaseAgent):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Knowledge search query for the agent dynamically generated by the agent.",
|
description="Knowledge search query for the agent dynamically generated by the agent.",
|
||||||
)
|
)
|
||||||
|
from_repository: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The Agent's role to be used from your repository.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def validate_from_repository(cls, v):
|
||||||
|
if v is not None and (from_repository := v.get("from_repository")):
|
||||||
|
return cls._load_agent_from_repository(from_repository) | v
|
||||||
|
return v
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_agent_from_repository(cls, from_repository: str) -> Dict[str, Any]:
|
||||||
|
attributes: Dict[str, Any] = {}
|
||||||
|
if from_repository:
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from crewai.cli.authentication.token import get_auth_token
|
||||||
|
from crewai.cli.plus_api import PlusAPI
|
||||||
|
|
||||||
|
client = PlusAPI(api_key=get_auth_token())
|
||||||
|
response = client.get_agent(from_repository)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise AgentRepositoryError(
|
||||||
|
f"Agent {from_repository} could not be loaded: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = response.json()
|
||||||
|
for key, value in agent.items():
|
||||||
|
if key == "tools":
|
||||||
|
attributes[key] = []
|
||||||
|
for tool_name in value:
|
||||||
|
try:
|
||||||
|
module = importlib.import_module("crewai_tools")
|
||||||
|
tool_class = getattr(module, tool_name)
|
||||||
|
attributes[key].append(tool_class())
|
||||||
|
except Exception as e:
|
||||||
|
raise AgentRepositoryError(
|
||||||
|
f"Tool {tool_name} could not be loaded: {e}"
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
attributes[key] = value
|
||||||
|
return attributes
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def post_init_setup(self):
|
def post_init_setup(self):
|
||||||
|
|||||||
@@ -5,5 +5,5 @@ def get_auth_token() -> str:
|
|||||||
"""Get the authentication token."""
|
"""Get the authentication token."""
|
||||||
access_token = TokenManager().get_token()
|
access_token = TokenManager().get_token()
|
||||||
if not access_token:
|
if not access_token:
|
||||||
raise Exception()
|
raise Exception("No token found, make sure you are logged in")
|
||||||
return access_token
|
return access_token
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class PlusAPI:
|
|||||||
|
|
||||||
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
|
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
|
||||||
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
|
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
|
||||||
|
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
|
||||||
|
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, api_key: str) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -37,6 +38,9 @@ class PlusAPI:
|
|||||||
def get_tool(self, handle: str):
|
def get_tool(self, handle: str):
|
||||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||||
|
|
||||||
|
def get_agent(self, handle: str):
|
||||||
|
return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
|
||||||
|
|
||||||
def publish_tool(
|
def publish_tool(
|
||||||
self,
|
self,
|
||||||
handle: str,
|
handle: str,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Error message definitions for CrewAI database operations."""
|
"""Error message definitions for CrewAI database operations."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
@@ -37,3 +38,9 @@ class DatabaseError:
|
|||||||
The formatted error message
|
The formatted error message
|
||||||
"""
|
"""
|
||||||
return template.format(str(error))
|
return template.format(str(error))
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRepositoryError(Exception):
|
||||||
|
"""Exception raised when an agent repository is not found."""
|
||||||
|
|
||||||
|
...
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -18,6 +18,7 @@ from crewai.tools import tool
|
|||||||
from crewai.tools.tool_calling import InstructorToolCalling
|
from crewai.tools.tool_calling import InstructorToolCalling
|
||||||
from crewai.tools.tool_usage import ToolUsage
|
from crewai.tools.tool_usage import ToolUsage
|
||||||
from crewai.utilities import RPMController
|
from crewai.utilities import RPMController
|
||||||
|
from crewai.utilities.errors import AgentRepositoryError
|
||||||
from crewai.utilities.events import crewai_event_bus
|
from crewai.utilities.events import crewai_event_bus
|
||||||
from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
|
from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
|
||||||
|
|
||||||
@@ -308,9 +309,7 @@ def test_cache_hitting():
|
|||||||
def handle_tool_end(source, event):
|
def handle_tool_end(source, event):
|
||||||
received_events.append(event)
|
received_events.append(event)
|
||||||
|
|
||||||
with (
|
with (patch.object(CacheHandler, "read") as read,):
|
||||||
patch.object(CacheHandler, "read") as read,
|
|
||||||
):
|
|
||||||
read.return_value = "0"
|
read.return_value = "0"
|
||||||
task = Task(
|
task = Task(
|
||||||
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
|
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
|
||||||
@@ -1040,7 +1039,7 @@ def test_agent_human_input():
|
|||||||
CrewAgentExecutor,
|
CrewAgentExecutor,
|
||||||
"_invoke_loop",
|
"_invoke_loop",
|
||||||
return_value=AgentFinish(output="Hello", thought="", text=""),
|
return_value=AgentFinish(output="Hello", thought="", text=""),
|
||||||
) as mock_invoke_loop,
|
),
|
||||||
):
|
):
|
||||||
# Execute the task
|
# Execute the task
|
||||||
output = agent.execute_task(task)
|
output = agent.execute_task(task)
|
||||||
@@ -2025,3 +2024,84 @@ def test_get_knowledge_search_query():
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||||
|
def test_agent_from_repository(mock_get_agent):
|
||||||
|
from crewai_tools import SerperDevTool
|
||||||
|
|
||||||
|
mock_get_response = MagicMock()
|
||||||
|
mock_get_response.status_code = 200
|
||||||
|
mock_get_response.json.return_value = {
|
||||||
|
"role": "test role",
|
||||||
|
"goal": "test goal",
|
||||||
|
"backstory": "test backstory",
|
||||||
|
"tools": ["SerperDevTool"],
|
||||||
|
}
|
||||||
|
mock_get_agent.return_value = mock_get_response
|
||||||
|
agent = Agent(from_repository="test_agent")
|
||||||
|
|
||||||
|
assert agent.role == "test role"
|
||||||
|
assert agent.goal == "test goal"
|
||||||
|
assert agent.backstory == "test backstory"
|
||||||
|
assert len(agent.tools) == 1
|
||||||
|
assert isinstance(agent.tools[0], SerperDevTool)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_get_auth_token():
|
||||||
|
with patch("crewai.cli.command.get_auth_token", return_value="test_token"):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||||
|
def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token):
|
||||||
|
from crewai_tools import SerperDevTool
|
||||||
|
|
||||||
|
mock_get_response = MagicMock()
|
||||||
|
mock_get_response.status_code = 200
|
||||||
|
mock_get_response.json.return_value = {
|
||||||
|
"role": "test role",
|
||||||
|
"goal": "test goal",
|
||||||
|
"backstory": "test backstory",
|
||||||
|
"tools": ["SerperDevTool"],
|
||||||
|
}
|
||||||
|
mock_get_agent.return_value = mock_get_response
|
||||||
|
agent = Agent(from_repository="test_agent", role="Custom Role")
|
||||||
|
|
||||||
|
assert agent.role == "Custom Role"
|
||||||
|
assert agent.goal == "test goal"
|
||||||
|
assert agent.backstory == "test backstory"
|
||||||
|
assert len(agent.tools) == 1
|
||||||
|
assert isinstance(agent.tools[0], SerperDevTool)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||||
|
def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_token):
|
||||||
|
mock_get_response = MagicMock()
|
||||||
|
mock_get_response.status_code = 200
|
||||||
|
mock_get_response.json.return_value = {
|
||||||
|
"role": "test role",
|
||||||
|
"goal": "test goal",
|
||||||
|
"backstory": "test backstory",
|
||||||
|
"tools": ["DoesNotExist"],
|
||||||
|
}
|
||||||
|
mock_get_agent.return_value = mock_get_response
|
||||||
|
with pytest.raises(
|
||||||
|
AgentRepositoryError,
|
||||||
|
match="Tool DoesNotExist could not be loaded: module 'crewai_tools' has no attribute 'DoesNotExist'",
|
||||||
|
):
|
||||||
|
Agent(from_repository="test_agent")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||||
|
def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_token):
|
||||||
|
mock_get_response = MagicMock()
|
||||||
|
mock_get_response.status_code = 404
|
||||||
|
mock_get_response.text = "Agent not found"
|
||||||
|
mock_get_agent.return_value = mock_get_response
|
||||||
|
with pytest.raises(
|
||||||
|
AgentRepositoryError,
|
||||||
|
match="Agent NOT_FOUND could not be loaded: Agent not found",
|
||||||
|
):
|
||||||
|
Agent(from_repository="NOT_FOUND")
|
||||||
|
|||||||
Reference in New Issue
Block a user