mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 07:38:29 +00:00
It works!
This commit is contained in:
22
byoa_tools.py
Normal file
22
byoa_tools.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Initialize a ChatOpenAI model
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
||||
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
# Create the agent with LangGraph
|
||||
memory = MemorySaver()
|
||||
agent_executor = create_react_agent(
|
||||
llm,
|
||||
tools,
|
||||
checkpointer=memory
|
||||
)
|
||||
|
||||
# Pass the LangGraph agent to the adapter
|
||||
wrapped_agent = LangChainAgentAdapter(
|
||||
langchain_agent=agent_executor,
|
||||
tools=tools,
|
||||
role="San Francisco Travel Advisor",
|
||||
goal="Curate a detailed list of the best neighborhoods to live in, restaurants to dine at, and attractions to visit in San Francisco.",
|
||||
backstory="An expert travel advisor with insider knowledge of San Francisco's vibrant culture, culinary delights, and hidden gems.",
|
||||
)
|
||||
@@ -69,6 +69,21 @@ class LangChainAgentAdapter(BaseAgent):
|
||||
new_tools.append(Tool.from_langchain(tool))
|
||||
return new_tools
|
||||
|
||||
def _extract_text(self, message: Any) -> str:
|
||||
"""
|
||||
Helper to extract plain text from a message object.
|
||||
This checks if the message is a dict with a "content" key, or has a "content" attribute.
|
||||
"""
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
return message["content"]
|
||||
elif hasattr(message, "content") and isinstance(
|
||||
getattr(message, "content"), str
|
||||
):
|
||||
return getattr(message, "content")
|
||||
elif isinstance(message, str):
|
||||
return message
|
||||
return str(message)
|
||||
|
||||
def execute_task(
|
||||
self,
|
||||
task: Task,
|
||||
@@ -146,55 +161,66 @@ class LangChainAgentAdapter(BaseAgent):
|
||||
else:
|
||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
try:
|
||||
# Initial invocation of the LangChain agent
|
||||
result = self.agent_executor.invoke(
|
||||
{
|
||||
"input": task_prompt,
|
||||
"tool_names": getattr(self.agent_executor, "tools_names", ""),
|
||||
"tools": getattr(self.agent_executor, "tools_description", ""),
|
||||
"ask_for_human_input": task.human_input,
|
||||
}
|
||||
)["output"]
|
||||
# IMPORTANT: create an initial state using "messages" (not "input")
|
||||
init_state = {"messages": [("user", task_prompt)]}
|
||||
state = self.agent_executor.invoke(init_state)
|
||||
|
||||
# If human feedback is required, enter a feedback loop
|
||||
if task.human_input:
|
||||
result = self._handle_human_feedback(result)
|
||||
except Exception as e:
|
||||
# Example: you could add retry logic here if desired.
|
||||
raise e
|
||||
if "structured_response" in state:
|
||||
current_output = state["structured_response"]
|
||||
elif "messages" in state and state["messages"]:
|
||||
last_message = state["messages"][-1]
|
||||
if isinstance(last_message, tuple):
|
||||
current_output = last_message[1]
|
||||
else:
|
||||
current_output = self._extract_text(last_message)
|
||||
else:
|
||||
current_output = ""
|
||||
|
||||
return result
|
||||
# If human feedback is required, enter a feedback loop
|
||||
if task.human_input:
|
||||
current_output = self._handle_human_feedback(current_output)
|
||||
|
||||
return current_output
|
||||
|
||||
def _handle_human_feedback(self, current_output: str) -> str:
|
||||
"""
|
||||
Implements a feedback loop that prompts the user for feedback and then instructs
|
||||
the underlying LangChain agent to regenerate its answer with the requested changes.
|
||||
Only the inner content of the output is displayed to the user.
|
||||
"""
|
||||
while True:
|
||||
print("\nAgent output:")
|
||||
print(current_output)
|
||||
# Prompt the user for feedback
|
||||
# Print only the inner text extracted from current_output.
|
||||
print(self._extract_text(current_output))
|
||||
|
||||
feedback = input("\nEnter your feedback (or press Enter to accept): ")
|
||||
if not feedback.strip():
|
||||
break # No feedback provided, exit the loop
|
||||
|
||||
# Construct a new prompt with explicit instructions
|
||||
extracted_output = self._extract_text(current_output)
|
||||
new_prompt = (
|
||||
f"Below is your previous answer:\n{current_output}\n\n"
|
||||
f"Below is your previous answer:\n"
|
||||
f"{extracted_output}\n\n"
|
||||
f"Based on the following feedback: '{feedback}', please regenerate your answer with the requested details. "
|
||||
f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\nUpdated answer:"
|
||||
f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\n"
|
||||
f"Updated answer:"
|
||||
)
|
||||
try:
|
||||
invocation = self.agent_executor.invoke(
|
||||
{
|
||||
"input": new_prompt,
|
||||
"tool_names": getattr(self.agent_executor, "tools_names", ""),
|
||||
"tools": getattr(self.agent_executor, "tools_description", ""),
|
||||
"ask_for_human_input": True,
|
||||
}
|
||||
# Use "messages" key for the prompt, like we do in execute_task.
|
||||
new_state = self.agent_executor.invoke(
|
||||
{"messages": [("user", new_prompt)]}
|
||||
)
|
||||
current_output = invocation["output"]
|
||||
if "structured_response" in new_state:
|
||||
new_output = new_state["structured_response"]
|
||||
elif "messages" in new_state and new_state["messages"]:
|
||||
last_message = new_state["messages"][-1]
|
||||
if isinstance(last_message, tuple):
|
||||
new_output = last_message[1]
|
||||
else:
|
||||
new_output = self._extract_text(last_message)
|
||||
else:
|
||||
new_output = ""
|
||||
current_output = new_output
|
||||
except Exception as e:
|
||||
print("Error during re-invocation with feedback:", e)
|
||||
break
|
||||
@@ -247,22 +273,21 @@ class LangChainAgentAdapter(BaseAgent):
|
||||
|
||||
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||
"""
|
||||
Creates an agent executor using LangChain's AgentExecutor.
|
||||
Creates an agent executor using LangGraph's create_react_agent if given an LLM,
|
||||
or uses the provided language model directly.
|
||||
"""
|
||||
try:
|
||||
from langchain.agents import AgentExecutor
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"LangChain library not found. Please run `uv add langchain` to add LangChain support."
|
||||
"LangGraph library not found. Please run `uv add langgraph` to add LangGraph support."
|
||||
) from e
|
||||
|
||||
# Use the following fallback strategy:
|
||||
# 1. If tools were passed in, use them.
|
||||
# 2. Otherwise, if self.tools exists, use them.
|
||||
# 3. Otherwise, try to extract the tools set in the underlying langchain agent.
|
||||
# Otherwise, create a new executor from the LLM.
|
||||
raw_tools = tools or self.tools
|
||||
|
||||
# Fallback: if raw_tools is empty, try to extract them from the wrapped langchain agent.
|
||||
if not raw_tools:
|
||||
# Try getting the tools from the agent's inner 'agent' attribute.
|
||||
if hasattr(self.langchain_agent, "agent") and hasattr(
|
||||
self.langchain_agent.agent, "tools"
|
||||
):
|
||||
@@ -270,10 +295,16 @@ class LangChainAgentAdapter(BaseAgent):
|
||||
else:
|
||||
raw_tools = getattr(self.langchain_agent, "tools", [])
|
||||
|
||||
# Convert each CrewAI tool to a native LangChain tool if possible.
|
||||
used_tools = []
|
||||
try:
|
||||
# Import the CrewAI Tool class.
|
||||
from crewai.tools.base_tool import Tool as CrewTool
|
||||
except ImportError:
|
||||
CrewTool = None
|
||||
|
||||
for tool in raw_tools:
|
||||
if hasattr(tool, "to_langchain"):
|
||||
# Only attempt conversion if this is an instance of our CrewAI Tool.
|
||||
if CrewTool is not None and isinstance(tool, CrewTool):
|
||||
used_tools.append(tool.to_langchain())
|
||||
else:
|
||||
used_tools.append(tool)
|
||||
@@ -281,10 +312,17 @@ class LangChainAgentAdapter(BaseAgent):
|
||||
print("Raw tools:", raw_tools)
|
||||
print("Used tools:", used_tools)
|
||||
|
||||
self.agent_executor = AgentExecutor.from_agent_and_tools(
|
||||
agent=self.langchain_agent,
|
||||
# Sanitize the agent's role for the "name" field. The allowed pattern is ^[a-zA-Z0-9_-]+$
|
||||
import re
|
||||
|
||||
agent_role = getattr(self, "role", "agent")
|
||||
sanitized_role = re.sub(r"\s+", "_", agent_role)
|
||||
|
||||
self.agent_executor = create_react_agent(
|
||||
model=self.langchain_agent,
|
||||
tools=used_tools,
|
||||
verbose=getattr(self, "verbose", True),
|
||||
debug=getattr(self, "verbose", False),
|
||||
name=sanitized_role,
|
||||
)
|
||||
|
||||
def _parse_tools(self, tools: List[BaseTool]) -> List[BaseTool]:
|
||||
|
||||
@@ -181,33 +181,33 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
return origin.__name__
|
||||
|
||||
def to_langchain(self) -> Any:
|
||||
"""
|
||||
Convert this CrewAI Tool instance into a LangChain-compatible tool.
|
||||
Returns a concrete subclass of LangChain's BaseTool.
|
||||
"""
|
||||
try:
|
||||
from langchain_core.tools import Tool as LC_Tool
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"LangChain library not found. Please run `uv add langchain` to add LangChain support."
|
||||
) from e
|
||||
# def to_langchain(self) -> Any:
|
||||
# """
|
||||
# Convert this CrewAI Tool instance into a LangChain-compatible tool.
|
||||
# Returns a concrete subclass of LangChain's BaseTool.
|
||||
# """
|
||||
# try:
|
||||
# from langchain_core.tools import Tool as LC_Tool
|
||||
# except ImportError as e:
|
||||
# raise ImportError(
|
||||
# "LangChain library not found. Please run `uv add langchain` to add LangChain support."
|
||||
# ) from e
|
||||
|
||||
# Capture the function in a local variable to avoid referencing None.
|
||||
tool_func = self.func
|
||||
# # Capture the function in a local variable to avoid referencing None.
|
||||
# tool_func = self.func
|
||||
|
||||
class ConcreteLangChainTool(LC_Tool):
|
||||
def _run(self, *args, **kwargs):
|
||||
return tool_func(*args, **kwargs)
|
||||
# class ConcreteLangChainTool(LC_Tool):
|
||||
# def _run(self, *args, **kwargs):
|
||||
# return tool_func(*args, **kwargs)
|
||||
|
||||
# Do not pass callback_manager; let LC_Tool use its default.
|
||||
print("Creating concrete langchain tool")
|
||||
return ConcreteLangChainTool(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
func=tool_func,
|
||||
args_schema=self.args_schema,
|
||||
)
|
||||
# # Do not pass callback_manager; let LC_Tool use its default.
|
||||
# print("Creating concrete langchain tool")
|
||||
# return ConcreteLangChainTool(
|
||||
# name=self.name,
|
||||
# description=self.description,
|
||||
# func=self._run,
|
||||
# args_schema=self.args_schema,
|
||||
# )
|
||||
|
||||
@property
|
||||
def get(self) -> Callable[[str, Any], Any]:
|
||||
@@ -227,56 +227,21 @@ class Tool(BaseTool):
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_langchain(cls, tool: Any) -> "Tool":
|
||||
"""Convert a LangChain tool to a CrewAI Tool."""
|
||||
# Handle missing args_schema
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
if args_schema is None:
|
||||
# Create default args schema
|
||||
args_schema = create_model(f"{tool.name}Input", __base__=PydanticBaseModel)
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"func": tool._run, # LangChain tools use _run
|
||||
"args_schema": args_schema,
|
||||
}
|
||||
# Create and validate a new instance directly from the dictionary
|
||||
return cls.model_validate(tool_dict)
|
||||
|
||||
def to_langchain(self) -> Any:
|
||||
"""Convert to LangChain tool with proper get method."""
|
||||
"""Convert to a LangChain-compatible tool."""
|
||||
try:
|
||||
from langchain_core.tools import Tool as LC_Tool
|
||||
except ImportError:
|
||||
raise ImportError("langchain_core is not installed")
|
||||
|
||||
LC_Tool(
|
||||
# Use self._run (which is bound and calls self.func) so that the LC_Tool gets proper attributes.
|
||||
return LC_Tool(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
func=self.func,
|
||||
func=self._run,
|
||||
args_schema=self.args_schema,
|
||||
)
|
||||
|
||||
# # Create subclass with get method
|
||||
# class PatchedTool(LC_Tool):
|
||||
# def get(self, key: str, default: Any = None) -> Any:
|
||||
# return getattr(self, key, default)
|
||||
|
||||
# return PatchedTool(
|
||||
# name=self.name,
|
||||
# description=self.description,
|
||||
# func=self.func,
|
||||
# args_schema=self.args_schema,
|
||||
# callback_manager=None,
|
||||
# )
|
||||
|
||||
|
||||
def to_langchain(
|
||||
tools: list[BaseTool | CrewStructuredTool],
|
||||
) -> list[CrewStructuredTool]:
|
||||
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
|
||||
|
||||
|
||||
def tool(*args):
|
||||
"""
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
Reference in New Issue
Block a user