mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Merge branch 'main' into lg-support-set-task-context
This commit is contained in:
@@ -20,6 +20,7 @@ from crewai.tools.agent_tools.agent_tools import AgentTools
|
|||||||
from crewai.utilities import Converter, Prompts
|
from crewai.utilities import Converter, Prompts
|
||||||
from crewai.utilities.agent_utils import (
|
from crewai.utilities.agent_utils import (
|
||||||
get_tool_names,
|
get_tool_names,
|
||||||
|
load_agent_from_repository,
|
||||||
parse_tools,
|
parse_tools,
|
||||||
render_text_description_and_args,
|
render_text_description_and_args,
|
||||||
)
|
)
|
||||||
@@ -134,6 +135,16 @@ class Agent(BaseAgent):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Knowledge search query for the agent dynamically generated by the agent.",
|
description="Knowledge search query for the agent dynamically generated by the agent.",
|
||||||
)
|
)
|
||||||
|
from_repository: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The Agent's role to be used from your repository.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def validate_from_repository(cls, v):
|
||||||
|
if v is not None and (from_repository := v.get("from_repository")):
|
||||||
|
return load_agent_from_repository(from_repository) | v
|
||||||
|
return v
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def post_init_setup(self):
|
def post_init_setup(self):
|
||||||
|
|||||||
@@ -5,5 +5,5 @@ def get_auth_token() -> str:
|
|||||||
"""Get the authentication token."""
|
"""Get the authentication token."""
|
||||||
access_token = TokenManager().get_token()
|
access_token = TokenManager().get_token()
|
||||||
if not access_token:
|
if not access_token:
|
||||||
raise Exception()
|
raise Exception("No token found, make sure you are logged in")
|
||||||
return access_token
|
return access_token
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class PlusAPI:
|
|||||||
|
|
||||||
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
|
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
|
||||||
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
|
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
|
||||||
|
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
|
||||||
|
|
||||||
def __init__(self, api_key: str) -> None:
|
def __init__(self, api_key: str) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -37,6 +38,9 @@ class PlusAPI:
|
|||||||
def get_tool(self, handle: str):
|
def get_tool(self, handle: str):
|
||||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||||
|
|
||||||
|
def get_agent(self, handle: str):
|
||||||
|
return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
|
||||||
|
|
||||||
def publish_tool(
|
def publish_tool(
|
||||||
self,
|
self,
|
||||||
handle: str,
|
handle: str,
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager, redirect_stderr, redirect_stdout
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
DefaultDict,
|
DefaultDict,
|
||||||
@@ -31,7 +30,6 @@ from crewai.utilities.events.llm_events import (
|
|||||||
LLMCallType,
|
LLMCallType,
|
||||||
LLMStreamChunkEvent,
|
LLMStreamChunkEvent,
|
||||||
)
|
)
|
||||||
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore", UserWarning)
|
warnings.simplefilter("ignore", UserWarning)
|
||||||
@@ -45,6 +43,9 @@ with warnings.catch_warnings():
|
|||||||
from litellm.utils import supports_response_schema
|
from litellm.utils import supports_response_schema
|
||||||
|
|
||||||
|
|
||||||
|
import io
|
||||||
|
from typing import TextIO
|
||||||
|
|
||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from crewai.utilities.events import crewai_event_bus
|
from crewai.utilities.events import crewai_event_bus
|
||||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||||
@@ -54,12 +55,17 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
class FilteredStream:
|
class FilteredStream(io.TextIOBase):
|
||||||
def __init__(self, original_stream):
|
_lock = None
|
||||||
|
|
||||||
|
def __init__(self, original_stream: TextIO):
|
||||||
self._original_stream = original_stream
|
self._original_stream = original_stream
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def write(self, s) -> int:
|
def write(self, s: str) -> int:
|
||||||
|
if not self._lock:
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
# Filter out extraneous messages from LiteLLM
|
# Filter out extraneous messages from LiteLLM
|
||||||
if (
|
if (
|
||||||
@@ -214,15 +220,11 @@ def suppress_warnings():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Redirect stdout and stderr
|
# Redirect stdout and stderr
|
||||||
old_stdout = sys.stdout
|
with (
|
||||||
old_stderr = sys.stderr
|
redirect_stdout(FilteredStream(sys.stdout)),
|
||||||
sys.stdout = FilteredStream(old_stdout)
|
redirect_stderr(FilteredStream(sys.stderr)),
|
||||||
sys.stderr = FilteredStream(old_stderr)
|
):
|
||||||
try:
|
|
||||||
yield
|
yield
|
||||||
finally:
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
sys.stderr = old_stderr
|
|
||||||
|
|
||||||
|
|
||||||
class Delta(TypedDict):
|
class Delta(TypedDict):
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from crewai.tools.base_tool import BaseTool
|
|||||||
from crewai.tools.structured_tool import CrewStructuredTool
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
from crewai.tools.tool_types import ToolResult
|
from crewai.tools.tool_types import ToolResult
|
||||||
from crewai.utilities import I18N, Printer
|
from crewai.utilities import I18N, Printer
|
||||||
|
from crewai.utilities.errors import AgentRepositoryError
|
||||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||||
LLMContextLengthExceededException,
|
LLMContextLengthExceededException,
|
||||||
)
|
)
|
||||||
@@ -428,3 +429,36 @@ def show_agent_logs(
|
|||||||
printer.print(
|
printer.print(
|
||||||
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
|
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_agent_from_repository(from_repository: str) -> Dict[str, Any]:
|
||||||
|
attributes: Dict[str, Any] = {}
|
||||||
|
if from_repository:
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from crewai.cli.authentication.token import get_auth_token
|
||||||
|
from crewai.cli.plus_api import PlusAPI
|
||||||
|
|
||||||
|
client = PlusAPI(api_key=get_auth_token())
|
||||||
|
response = client.get_agent(from_repository)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise AgentRepositoryError(
|
||||||
|
f"Agent {from_repository} could not be loaded: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = response.json()
|
||||||
|
for key, value in agent.items():
|
||||||
|
if key == "tools":
|
||||||
|
attributes[key] = []
|
||||||
|
for tool_name in value:
|
||||||
|
try:
|
||||||
|
module = importlib.import_module("crewai_tools")
|
||||||
|
tool_class = getattr(module, tool_name)
|
||||||
|
attributes[key].append(tool_class())
|
||||||
|
except Exception as e:
|
||||||
|
raise AgentRepositoryError(
|
||||||
|
f"Tool {tool_name} could not be loaded: {e}"
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
attributes[key] = value
|
||||||
|
return attributes
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Error message definitions for CrewAI database operations."""
|
"""Error message definitions for CrewAI database operations."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
@@ -37,3 +38,9 @@ class DatabaseError:
|
|||||||
The formatted error message
|
The formatted error message
|
||||||
"""
|
"""
|
||||||
return template.format(str(error))
|
return template.format(str(error))
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRepositoryError(Exception):
|
||||||
|
"""Exception raised when an agent repository is not found."""
|
||||||
|
|
||||||
|
...
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -18,6 +18,7 @@ from crewai.tools import tool
|
|||||||
from crewai.tools.tool_calling import InstructorToolCalling
|
from crewai.tools.tool_calling import InstructorToolCalling
|
||||||
from crewai.tools.tool_usage import ToolUsage
|
from crewai.tools.tool_usage import ToolUsage
|
||||||
from crewai.utilities import RPMController
|
from crewai.utilities import RPMController
|
||||||
|
from crewai.utilities.errors import AgentRepositoryError
|
||||||
from crewai.utilities.events import crewai_event_bus
|
from crewai.utilities.events import crewai_event_bus
|
||||||
from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
|
from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
|
||||||
|
|
||||||
@@ -308,9 +309,7 @@ def test_cache_hitting():
|
|||||||
def handle_tool_end(source, event):
|
def handle_tool_end(source, event):
|
||||||
received_events.append(event)
|
received_events.append(event)
|
||||||
|
|
||||||
with (
|
with (patch.object(CacheHandler, "read") as read,):
|
||||||
patch.object(CacheHandler, "read") as read,
|
|
||||||
):
|
|
||||||
read.return_value = "0"
|
read.return_value = "0"
|
||||||
task = Task(
|
task = Task(
|
||||||
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
|
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
|
||||||
@@ -1040,7 +1039,7 @@ def test_agent_human_input():
|
|||||||
CrewAgentExecutor,
|
CrewAgentExecutor,
|
||||||
"_invoke_loop",
|
"_invoke_loop",
|
||||||
return_value=AgentFinish(output="Hello", thought="", text=""),
|
return_value=AgentFinish(output="Hello", thought="", text=""),
|
||||||
) as mock_invoke_loop,
|
),
|
||||||
):
|
):
|
||||||
# Execute the task
|
# Execute the task
|
||||||
output = agent.execute_task(task)
|
output = agent.execute_task(task)
|
||||||
@@ -2025,3 +2024,86 @@ def test_get_knowledge_search_query():
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_get_auth_token():
|
||||||
|
with patch(
|
||||||
|
"crewai.cli.authentication.token.get_auth_token", return_value="test_token"
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||||
|
def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
|
||||||
|
from crewai_tools import SerperDevTool
|
||||||
|
|
||||||
|
mock_get_response = MagicMock()
|
||||||
|
mock_get_response.status_code = 200
|
||||||
|
mock_get_response.json.return_value = {
|
||||||
|
"role": "test role",
|
||||||
|
"goal": "test goal",
|
||||||
|
"backstory": "test backstory",
|
||||||
|
"tools": ["SerperDevTool"],
|
||||||
|
}
|
||||||
|
mock_get_agent.return_value = mock_get_response
|
||||||
|
agent = Agent(from_repository="test_agent")
|
||||||
|
|
||||||
|
assert agent.role == "test role"
|
||||||
|
assert agent.goal == "test goal"
|
||||||
|
assert agent.backstory == "test backstory"
|
||||||
|
assert len(agent.tools) == 1
|
||||||
|
assert isinstance(agent.tools[0], SerperDevTool)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||||
|
def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token):
|
||||||
|
from crewai_tools import SerperDevTool
|
||||||
|
|
||||||
|
mock_get_response = MagicMock()
|
||||||
|
mock_get_response.status_code = 200
|
||||||
|
mock_get_response.json.return_value = {
|
||||||
|
"role": "test role",
|
||||||
|
"goal": "test goal",
|
||||||
|
"backstory": "test backstory",
|
||||||
|
"tools": ["SerperDevTool"],
|
||||||
|
}
|
||||||
|
mock_get_agent.return_value = mock_get_response
|
||||||
|
agent = Agent(from_repository="test_agent", role="Custom Role")
|
||||||
|
|
||||||
|
assert agent.role == "Custom Role"
|
||||||
|
assert agent.goal == "test goal"
|
||||||
|
assert agent.backstory == "test backstory"
|
||||||
|
assert len(agent.tools) == 1
|
||||||
|
assert isinstance(agent.tools[0], SerperDevTool)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||||
|
def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_token):
|
||||||
|
mock_get_response = MagicMock()
|
||||||
|
mock_get_response.status_code = 200
|
||||||
|
mock_get_response.json.return_value = {
|
||||||
|
"role": "test role",
|
||||||
|
"goal": "test goal",
|
||||||
|
"backstory": "test backstory",
|
||||||
|
"tools": ["DoesNotExist"],
|
||||||
|
}
|
||||||
|
mock_get_agent.return_value = mock_get_response
|
||||||
|
with pytest.raises(
|
||||||
|
AgentRepositoryError,
|
||||||
|
match="Tool DoesNotExist could not be loaded: module 'crewai_tools' has no attribute 'DoesNotExist'",
|
||||||
|
):
|
||||||
|
Agent(from_repository="test_agent")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.cli.plus_api.PlusAPI.get_agent")
|
||||||
|
def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_token):
|
||||||
|
mock_get_response = MagicMock()
|
||||||
|
mock_get_response.status_code = 404
|
||||||
|
mock_get_response.text = "Agent not found"
|
||||||
|
mock_get_agent.return_value = mock_get_response
|
||||||
|
with pytest.raises(
|
||||||
|
AgentRepositoryError,
|
||||||
|
match="Agent NOT_FOUND could not be loaded: Agent not found",
|
||||||
|
):
|
||||||
|
Agent(from_repository="NOT_FOUND")
|
||||||
|
|||||||
Reference in New Issue
Block a user