refactor: enhance tool handling in agent adapters

- Updated BaseToolAdapter to initialize original and converted tools in the constructor.
- Renamed method `all_tools` to `tools` for clarity in BaseToolAdapter.
- Added `sanitize_tool_name` method to ensure tool names are API compatible.
- Modified LangGraphAgentAdapter to utilize the updated tool handling and ensure proper tool configuration.
- Refactored LangGraphToolAdapter to streamline tool conversion and ensure consistent naming conventions.
This commit is contained in:
lorenzejay
2025-04-14 12:21:58 -07:00
parent 7579d91499
commit 5f4e645f10
3 changed files with 26 additions and 26 deletions

View File

@@ -12,11 +12,12 @@ class BaseToolAdapter(ABC):
different frameworks and platforms.
"""
original_tools: List[BaseTool] = []
converted_tools: List[Any] = []
original_tools: List[BaseTool]
converted_tools: List[Any]
def __init__(self, tools: Optional[List[BaseTool]] = None):
self.tools = tools or []
self.original_tools = tools or []
self.converted_tools = []
@abstractmethod
def configure_tools(self, tools: List[BaseTool]) -> None:
@@ -27,6 +28,10 @@ class BaseToolAdapter(ABC):
"""
pass
def all_tools(self) -> List[Any]:
def tools(self) -> List[Any]:
"""Return all converted tools."""
return self.converted_tools
def sanitize_tool_name(self, tool_name: str) -> str:
"""Sanitize tool name for API compatibility."""
return tool_name.replace(" ", "_")

View File

@@ -71,6 +71,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
agent_config=agent_config,
**kwargs,
)
self.tools = tools or []
self._tool_adapter = LangGraphToolAdapter(tools=tools or [])
self._converter_adapter = LangGraphConverterAdapter(self)
self._max_iterations = max_iterations
@@ -79,14 +80,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
def _setup_graph(self) -> None:
"""Set up the LangGraph workflow graph."""
try:
# Initialize memory for the agent
self._memory = MemorySaver()
converted_tools = self._tool_adapter.converted_tools
converted_tools: List[Any] = self._tool_adapter.tools()
self._graph = create_react_agent(
model=self.llm,
tools=converted_tools,
tools=converted_tools or [],
checkpointer=self._memory,
debug=self.verbose,
)
@@ -142,13 +142,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
),
)
# Set up a session ID for this task
session_id = f"task_{id(task)}"
# Configure the invocation
config = {"configurable": {"thread_id": session_id}}
# Invoke the agent graph with the task prompt
result = self._graph.invoke(
{
"messages": [
@@ -159,7 +156,6 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
config,
)
# Get the final response
messages = result.get("messages", [])
last_message = messages[-1] if messages else None
@@ -169,7 +165,6 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
elif hasattr(last_message, "content"):
final_answer = getattr(last_message, "content", "")
# Post-process to ensure correct structured output format if needed
final_answer = (
self._converter_adapter.post_process_result(final_answer)
or "Task execution completed but no clear answer was provided."
@@ -256,18 +251,15 @@ class LangGraphAgentAdapter(BaseAgentAdapter):
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure the LangGraph agent for execution."""
if tools:
self.configure_tools(tools)
# No need for a separate executor in LangGraph
self.configure_tools(tools)
def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None:
"""Configure tools for the LangGraph agent."""
if tools:
all_tools = list(self.tools or []) + list(tools or [])
self._tool_adapter.configure_tools(all_tools)
# We need to recreate the graph with the new tools
self._setup_graph()
available_tools = self._tool_adapter.tools()
self._graph.tools = available_tools
def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]:
"""Implement delegation tools support for LangGraph."""

View File

@@ -9,6 +9,7 @@ class LangGraphToolAdapter(BaseToolAdapter):
def __init__(self, tools: Optional[List[BaseTool]] = None):
self.original_tools = tools or []
self.converted_tools = []
def configure_tools(self, tools: List[BaseTool]) -> None:
"""
@@ -17,16 +18,14 @@ class LangGraphToolAdapter(BaseToolAdapter):
"""
from langchain_core.tools import StructuredTool
self.tools = tools
self.converted_tools = []
converted_tools = []
if self.original_tools:
all_tools = tools + self.original_tools
else:
all_tools = tools
for tool in all_tools:
# Create a wrapper function that matches LangGraph's expected format
def tool_wrapper(*args, tool=tool, **kwargs):
# Extract inputs based on the tool's schema
if len(args) > 0 and isinstance(args[0], str):
return tool.run(args[0])
elif "input" in kwargs:
@@ -34,14 +33,18 @@ class LangGraphToolAdapter(BaseToolAdapter):
else:
return tool.run(**kwargs)
sanitized_tool_name = self.sanitize_tool_name(tool.name)
converted_tool = StructuredTool(
name=tool.name.replace(" ", "_"),
name=sanitized_tool_name,
description=tool.description,
func=tool_wrapper,
args_schema=tool.args_schema,
)
self.converted_tools.append(converted_tool)
converted_tools.append(converted_tool)
def all_tools(self) -> List[Any]:
return self.converted_tools
self.converted_tools = converted_tools
def tools(self) -> List[Any]:
return self.converted_tools or []