tools fix

This commit is contained in:
Lorenze Jay
2024-07-15 08:06:28 -07:00
parent f27c8e728d
commit d28ae857b7
3 changed files with 26 additions and 4 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, List, Optional, Tuple
from langchain.agents.agent import RunnableAgent from langchain.agents.agent import RunnableAgent
from langchain.agents.tools import BaseTool from langchain.agents.tools import BaseTool
from langchain.agents.tools import tool as LangChainTool from langchain.agents.tools import tool as LangChainTool
from langchain.tools.base import StructuredTool
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
@@ -298,7 +299,7 @@ class Agent(BaseAgent):
agent=RunnableAgent(runnable=inner_agent), **executor_args agent=RunnableAgent(runnable=inner_agent), **executor_args
) )
def get_delegation_tools(self, agents: List[BaseAgent]): def get_delegation_tools(self, agents: List[BaseAgent]) -> List[StructuredTool]:
agent_tools = AgentTools(agents=agents) agent_tools = AgentTools(agents=agents)
tools = agent_tools.tools() tools = agent_tools.tools()
return tools return tools

View File

@@ -180,7 +180,7 @@ class BaseAgent(ABC, BaseModel):
pass pass
@abstractmethod @abstractmethod
def get_delegation_tools(self, agents: List["BaseAgent"]): def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[Any]:
"""Set the task tools that init BaseAgenTools class.""" """Set the task tools that init BaseAgenTools class."""
pass pass

View File

@@ -654,8 +654,29 @@ class Crew(BaseModel):
def _add_delegation_tools(self, task: Task): def _add_delegation_tools(self, task: Task):
agents_for_delegation = [agent for agent in self.agents if agent != task.agent] agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and agents_for_delegation: if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
task.tools += task.agent.get_delegation_tools(agents_for_delegation) # type: ignore delegation_tools = task.agent.get_delegation_tools(agents_for_delegation)
# Add tools if they are not already in task.tools
for new_tool in delegation_tools:
# Find the index of the tool with the same name
existing_tool_index = next(
(
index
for index, tool in enumerate(task.tools or [])
if tool.name == new_tool.name
),
None,
)
if not task.tools:
task.tools = []
if existing_tool_index is not None:
# Replace the existing tool
task.tools[existing_tool_index] = new_tool
else:
# Add the new tool
task.tools.append(new_tool)
def _log_task_start(self, task: Task, agent: Optional[BaseAgent]): def _log_task_start(self, task: Task, agent: Optional[BaseAgent]):
color = self._logging_color color = self._logging_color