mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-05 09:12:39 +00:00
refactor: enhance tool handling in agent adapters
- Updated BaseToolAdapter to initialize original and converted tools in the constructor. - Renamed method `all_tools` to `tools` for clarity in BaseToolAdapter. - Added `sanitize_tool_name` method to ensure tool names are API compatible. - Modified LangGraphAgentAdapter to utilize the updated tool handling and ensure proper tool configuration. - Refactored LangGraphToolAdapter to streamline tool conversion and ensure consistent naming conventions.
This commit is contained in:
@@ -12,11 +12,12 @@ class BaseToolAdapter(ABC):
|
||||
different frameworks and platforms.
|
||||
"""
|
||||
|
||||
original_tools: List[BaseTool] = []
|
||||
converted_tools: List[Any] = []
|
||||
original_tools: List[BaseTool]
|
||||
converted_tools: List[Any]
|
||||
|
||||
def __init__(self, tools: Optional[List[BaseTool]] = None):
|
||||
self.tools = tools or []
|
||||
self.original_tools = tools or []
|
||||
self.converted_tools = []
|
||||
|
||||
@abstractmethod
|
||||
def configure_tools(self, tools: List[BaseTool]) -> None:
|
||||
@@ -27,6 +28,10 @@ class BaseToolAdapter(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def all_tools(self) -> List[Any]:
|
||||
def tools(self) -> List[Any]:
|
||||
"""Return all converted tools."""
|
||||
return self.converted_tools
|
||||
|
||||
def sanitize_tool_name(self, tool_name: str) -> str:
|
||||
"""Sanitize tool name for API compatibility."""
|
||||
return tool_name.replace(" ", "_")
|
||||
|
||||
@@ -71,6 +71,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
agent_config=agent_config,
|
||||
**kwargs,
|
||||
)
|
||||
self.tools = tools or []
|
||||
self._tool_adapter = LangGraphToolAdapter(tools=tools or [])
|
||||
self._converter_adapter = LangGraphConverterAdapter(self)
|
||||
self._max_iterations = max_iterations
|
||||
@@ -79,14 +80,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
def _setup_graph(self) -> None:
|
||||
"""Set up the LangGraph workflow graph."""
|
||||
try:
|
||||
# Initialize memory for the agent
|
||||
self._memory = MemorySaver()
|
||||
|
||||
converted_tools = self._tool_adapter.converted_tools
|
||||
converted_tools: List[Any] = self._tool_adapter.tools()
|
||||
|
||||
self._graph = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=converted_tools,
|
||||
tools=converted_tools or [],
|
||||
checkpointer=self._memory,
|
||||
debug=self.verbose,
|
||||
)
|
||||
@@ -142,13 +142,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
),
|
||||
)
|
||||
|
||||
# Set up a session ID for this task
|
||||
session_id = f"task_{id(task)}"
|
||||
|
||||
# Configure the invocation
|
||||
config = {"configurable": {"thread_id": session_id}}
|
||||
|
||||
# Invoke the agent graph with the task prompt
|
||||
result = self._graph.invoke(
|
||||
{
|
||||
"messages": [
|
||||
@@ -159,7 +156,6 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
config,
|
||||
)
|
||||
|
||||
# Get the final response
|
||||
messages = result.get("messages", [])
|
||||
last_message = messages[-1] if messages else None
|
||||
|
||||
@@ -169,7 +165,6 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
elif hasattr(last_message, "content"):
|
||||
final_answer = getattr(last_message, "content", "")
|
||||
|
||||
# Post-process to ensure correct structured output format if needed
|
||||
final_answer = (
|
||||
self._converter_adapter.post_process_result(final_answer)
|
||||
or "Task execution completed but no clear answer was provided."
|
||||
@@ -256,18 +251,15 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
|
||||
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""Configure the LangGraph agent for execution."""
|
||||
if tools:
|
||||
self.configure_tools(tools)
|
||||
|
||||
# No need for a separate executor in LangGraph
|
||||
self.configure_tools(tools)
|
||||
|
||||
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""Configure tools for the LangGraph agent."""
|
||||
if tools:
|
||||
all_tools = list(self.tools or []) + list(tools or [])
|
||||
self._tool_adapter.configure_tools(all_tools)
|
||||
# We need to recreate the graph with the new tools
|
||||
self._setup_graph()
|
||||
available_tools = self._tool_adapter.tools()
|
||||
self._graph.tools = available_tools
|
||||
|
||||
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
|
||||
"""Implement delegation tools support for LangGraph."""
|
||||
|
||||
@@ -9,6 +9,7 @@ class LangGraphToolAdapter(BaseToolAdapter):
|
||||
|
||||
def __init__(self, tools: Optional[List[BaseTool]] = None):
|
||||
self.original_tools = tools or []
|
||||
self.converted_tools = []
|
||||
|
||||
def configure_tools(self, tools: List[BaseTool]) -> None:
|
||||
"""
|
||||
@@ -17,16 +18,14 @@ class LangGraphToolAdapter(BaseToolAdapter):
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
self.tools = tools
|
||||
self.converted_tools = []
|
||||
converted_tools = []
|
||||
if self.original_tools:
|
||||
all_tools = tools + self.original_tools
|
||||
else:
|
||||
all_tools = tools
|
||||
for tool in all_tools:
|
||||
# Create a wrapper function that matches LangGraph's expected format
|
||||
|
||||
def tool_wrapper(*args, tool=tool, **kwargs):
|
||||
# Extract inputs based on the tool's schema
|
||||
if len(args) > 0 and isinstance(args[0], str):
|
||||
return tool.run(args[0])
|
||||
elif "input" in kwargs:
|
||||
@@ -34,14 +33,18 @@ class LangGraphToolAdapter(BaseToolAdapter):
|
||||
else:
|
||||
return tool.run(**kwargs)
|
||||
|
||||
sanitized_tool_name = self.sanitize_tool_name(tool.name)
|
||||
|
||||
converted_tool = StructuredTool(
|
||||
name=tool.name.replace(" ", "_"),
|
||||
name=sanitized_tool_name,
|
||||
description=tool.description,
|
||||
func=tool_wrapper,
|
||||
args_schema=tool.args_schema,
|
||||
)
|
||||
|
||||
self.converted_tools.append(converted_tool)
|
||||
converted_tools.append(converted_tool)
|
||||
|
||||
def all_tools(self) -> List[Any]:
|
||||
return self.converted_tools
|
||||
self.converted_tools = converted_tools
|
||||
|
||||
def tools(self) -> List[Any]:
|
||||
return self.converted_tools or []
|
||||
|
||||
Reference in New Issue
Block a user