From 5f4e645f1060a0c3c26162de104f8ae949a00506 Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Mon, 14 Apr 2025 12:21:58 -0700 Subject: [PATCH] refactor: enhance tool handling in agent adapters - Updated BaseToolAdapter to initialize original and converted tools in the constructor. - Renamed method `all_tools` to `tools` for clarity in BaseToolAdapter. - Added `sanitize_tool_name` method to ensure tool names are API compatible. - Modified LangGraphAgentAdapter to utilize the updated tool handling and ensure proper tool configuration. - Refactored LangGraphToolAdapter to streamline tool conversion and ensure consistent naming conventions. --- .../agent_adapters/base_tool_adapter.py | 13 ++++++++---- .../langgraph/langgraph_adapter.py | 20 ++++++------------- .../langgraph/langgraph_tool_adapter.py | 19 ++++++++++-------- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/crewai/agents/agent_adapters/base_tool_adapter.py b/src/crewai/agents/agent_adapters/base_tool_adapter.py index 001df2e38..f1ee438a8 100644 --- a/src/crewai/agents/agent_adapters/base_tool_adapter.py +++ b/src/crewai/agents/agent_adapters/base_tool_adapter.py @@ -12,11 +12,12 @@ class BaseToolAdapter(ABC): different frameworks and platforms. """ - original_tools: List[BaseTool] = [] - converted_tools: List[Any] = [] + original_tools: List[BaseTool] + converted_tools: List[Any] def __init__(self, tools: Optional[List[BaseTool]] = None): - self.tools = tools or [] + self.original_tools = tools or [] + self.converted_tools = [] @abstractmethod def configure_tools(self, tools: List[BaseTool]) -> None: @@ -27,6 +28,10 @@ class BaseToolAdapter(ABC): """ pass - def all_tools(self) -> List[Any]: + def tools(self) -> List[Any]: """Return all converted tools.""" return self.converted_tools + + def sanitize_tool_name(self, tool_name: str) -> str: + """Sanitize tool name for API compatibility.""" + return tool_name.replace(" ", "_") diff --git a/src/crewai/agents/agent_adapters/langgraph/langgraph_adapter.py b/src/crewai/agents/agent_adapters/langgraph/langgraph_adapter.py index 9ab6599bd..49ca96aea 100644 --- a/src/crewai/agents/agent_adapters/langgraph/langgraph_adapter.py +++ b/src/crewai/agents/agent_adapters/langgraph/langgraph_adapter.py @@ -71,6 +71,7 @@ class LangGraphAgentAdapter(BaseAgentAdapter): agent_config=agent_config, **kwargs, ) + self.tools = tools or [] self._tool_adapter = LangGraphToolAdapter(tools=tools or []) self._converter_adapter = LangGraphConverterAdapter(self) self._max_iterations = max_iterations @@ -79,14 +80,13 @@ class LangGraphAgentAdapter(BaseAgentAdapter): def _setup_graph(self) -> None: """Set up the LangGraph workflow graph.""" try: - # Initialize memory for the agent self._memory = MemorySaver() - converted_tools = self._tool_adapter.converted_tools + converted_tools: List[Any] = self._tool_adapter.tools() self._graph = create_react_agent( model=self.llm, - tools=converted_tools, + tools=converted_tools or [], checkpointer=self._memory, debug=self.verbose, ) @@ -142,13 +142,10 @@ class LangGraphAgentAdapter(BaseAgentAdapter): ), ) - # Set up a session ID for this task session_id = f"task_{id(task)}" - # Configure the invocation config = {"configurable": {"thread_id": session_id}} - # Invoke the agent graph with the task prompt result = self._graph.invoke( { "messages": [ @@ -159,7 +156,6 @@ class LangGraphAgentAdapter(BaseAgentAdapter): config, ) - # Get the final response messages = result.get("messages", []) last_message = messages[-1] if messages else None @@ -169,7 +165,6 @@ class LangGraphAgentAdapter(BaseAgentAdapter): elif hasattr(last_message, "content"): final_answer = getattr(last_message, "content", "") - # Post-process to ensure correct structured output format if needed final_answer = ( self._converter_adapter.post_process_result(final_answer) or "Task execution completed but no clear answer was provided." @@ -256,18 +251,15 @@ class LangGraphAgentAdapter(BaseAgentAdapter): def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None: """Configure the LangGraph agent for execution.""" - if tools: - self.configure_tools(tools) - - # No need for a separate executor in LangGraph + self.configure_tools(tools) def configure_tools(self, tools: Optional[List[BaseTool]] = None) -> None: """Configure tools for the LangGraph agent.""" if tools: all_tools = list(self.tools or []) + list(tools or []) self._tool_adapter.configure_tools(all_tools) - # We need to recreate the graph with the new tools - self._setup_graph() + available_tools = self._tool_adapter.tools() + self._graph.tools = available_tools def get_delegation_tools(self, agents: List[BaseAgent]) -> List[BaseTool]: """Implement delegation tools support for LangGraph.""" diff --git a/src/crewai/agents/agent_adapters/langgraph/langgraph_tool_adapter.py b/src/crewai/agents/agent_adapters/langgraph/langgraph_tool_adapter.py index cdbf4bc5c..37c8e93e4 100644 --- a/src/crewai/agents/agent_adapters/langgraph/langgraph_tool_adapter.py +++ b/src/crewai/agents/agent_adapters/langgraph/langgraph_tool_adapter.py @@ -9,6 +9,7 @@ class LangGraphToolAdapter(BaseToolAdapter): def __init__(self, tools: Optional[List[BaseTool]] = None): self.original_tools = tools or [] + self.converted_tools = [] def configure_tools(self, tools: List[BaseTool]) -> None: """ @@ -17,16 +18,14 @@ class LangGraphToolAdapter(BaseToolAdapter): """ from langchain_core.tools import StructuredTool - self.tools = tools - self.converted_tools = [] + converted_tools = [] if self.original_tools: all_tools = tools + self.original_tools else: all_tools = tools for tool in all_tools: - # Create a wrapper function that matches LangGraph's expected format + def tool_wrapper(*args, tool=tool, **kwargs): - # Extract inputs based on the tool's schema if len(args) > 0 and isinstance(args[0], str): return tool.run(args[0]) elif "input" in kwargs: @@ -34,14 +33,18 @@ class LangGraphToolAdapter(BaseToolAdapter): else: return tool.run(**kwargs) + sanitized_tool_name = self.sanitize_tool_name(tool.name) + converted_tool = StructuredTool( - name=tool.name.replace(" ", "_"), + name=sanitized_tool_name, description=tool.description, func=tool_wrapper, args_schema=tool.args_schema, ) - self.converted_tools.append(converted_tool) + converted_tools.append(converted_tool) - def all_tools(self) -> List[Any]: - return self.converted_tools + self.converted_tools = converted_tools + + def tools(self) -> List[Any]: + return self.converted_tools or []