tools fix

This commit is contained in:
Lorenze Jay
2024-07-15 08:06:28 -07:00
parent f27c8e728d
commit d28ae857b7
3 changed files with 26 additions and 4 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, List, Optional, Tuple
from langchain.agents.agent import RunnableAgent
from langchain.agents.tools import BaseTool
from langchain.agents.tools import tool as LangChainTool
from langchain.tools.base import StructuredTool
from langchain_core.agents import AgentAction
from langchain_core.callbacks import BaseCallbackHandler
from langchain_openai import ChatOpenAI
@@ -298,7 +299,7 @@ class Agent(BaseAgent):
agent=RunnableAgent(runnable=inner_agent), **executor_args
)
def get_delegation_tools(self, agents: List[BaseAgent]):
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[StructuredTool]:
agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools()
return tools

View File

@@ -180,7 +180,7 @@ class BaseAgent(ABC, BaseModel):
pass
@abstractmethod
def get_delegation_tools(self, agents: List["BaseAgent"]):
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[Any]:
"""Set the task tools that init BaseAgenTools class."""
pass

View File

@@ -654,8 +654,29 @@ class Crew(BaseModel):
def _add_delegation_tools(self, task: Task):
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and agents_for_delegation:
task.tools += task.agent.get_delegation_tools(agents_for_delegation) # type: ignore
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
delegation_tools = task.agent.get_delegation_tools(agents_for_delegation)
# Add tools if they are not already in task.tools
for new_tool in delegation_tools:
# Find the index of the tool with the same name
existing_tool_index = next(
(
index
for index, tool in enumerate(task.tools or [])
if tool.name == new_tool.name
),
None,
)
if not task.tools:
task.tools = []
if existing_tool_index is not None:
# Replace the existing tool
task.tools[existing_tool_index] = new_tool
else:
# Add the new tool
task.tools.append(new_tool)
def _log_task_start(self, task: Task, agent: Optional[BaseAgent]):
color = self._logging_color