adding function calling llm support

This commit is contained in:
João Moura
2024-02-13 02:57:12 -08:00
parent 2410d0c531
commit 55c0c186d1
11 changed files with 4378 additions and 14 deletions

View File

@@ -36,6 +36,7 @@ class Agent(BaseModel):
goal: The objective of the agent.
backstory: The backstory of the agent.
llm: The language model that will run the agent.
function_calling_llm: The language model that will the tool calling for this agent, it overrides the crew function_calling_llm.
max_iter: Maximum number of iterations for an agent to execute a task.
memory: Whether the agent should have memory or not.
max_rpm: Maximum number of requests per minute for the agent execution to be respected.
@@ -98,6 +99,9 @@ class Agent(BaseModel):
),
description="Language model that will run the agent.",
)
function_calling_llm: Optional[Any] = Field(
description="Language model that will run the agent.", default=None
)
@field_validator("id", mode="before")
@classmethod
@@ -140,7 +144,6 @@ class Agent(BaseModel):
Returns:
Output of the agent
"""
task_prompt = task.prompt()
if context:
@@ -151,7 +154,7 @@ class Agent(BaseModel):
tools = tools or self.tools
self.agent_executor.tools = tools
self.agent_executor.task = task
self.agent_executor.tools_description = (render_text_description(tools),)
self.agent_executor.tools_description = render_text_description(tools)
self.agent_executor.tools_names = self.__tools_names(tools)
result = self.agent_executor.invoke(
@@ -208,6 +211,7 @@ class Agent(BaseModel):
"max_iterations": self.max_iter,
"step_callback": self.step_callback,
"tools_handler": self.tools_handler,
"function_calling_llm": self.function_calling_llm,
}
if self._rpm_controller:

View File

@@ -24,6 +24,7 @@ class CrewAgentExecutor(AgentExecutor):
task: Any = None
tools_description: str = ""
tools_names: str = ""
function_calling_llm: Any = None
request_within_rpm_limit: Any = None
tools_handler: InstanceOf[ToolsHandler] = None
max_iterations: Optional[int] = 15
@@ -194,6 +195,7 @@ class CrewAgentExecutor(AgentExecutor):
tools=self.tools,
tools_description=self.tools_description,
tools_names=self.tools_names,
function_calling_llm=self.function_calling_llm,
llm=self.llm,
task=self.task,
).use(agent_action.log)

View File

@@ -22,9 +22,9 @@ class ToolsHandler:
def on_tool_end(self, calling: ToolCalling, output: str) -> Any:
"""Run when tool ends running."""
if self.last_used_tool.function_name != CacheTools().name:
if self.last_used_tool.tool_name != CacheTools().name:
self.cache.add(
tool=calling.function_name,
tool=calling.tool_name,
input=calling.arguments,
output=output,
)

View File

@@ -32,6 +32,7 @@ class Crew(BaseModel):
tasks: List of tasks assigned to the crew.
agents: List of agents part of this crew.
manager_llm: The language model that will run manager agent.
function_calling_llm: The language model that will run the tool calling for all the agents.
process: The process flow that the crew will follow (e.g., sequential).
verbose: Indicates the verbosity level for logging during execution.
config: Configuration settings for the crew.
@@ -60,6 +61,9 @@ class Crew(BaseModel):
manager_llm: Optional[Any] = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Any] = Field(
description="Language model that will run the agent.", default=None
)
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
share_crew: Optional[bool] = Field(default=False)
@@ -176,7 +180,11 @@ class Crew(BaseModel):
for agent in self.agents:
agent.i18n = I18N(language=self.language)
if (self.step_callback) and (not agent.step_callback):
if not agent.function_calling_llm:
agent.function_calling_llm = self.function_calling_llm
agent.create_agent_executor()
if not agent.step_callback:
agent.step_callback = self.step_callback
agent.create_agent_executor()

View File

@@ -41,6 +41,7 @@ class ToolUsage:
tools_names: str,
task: Any,
llm: Any,
function_calling_llm: Any,
) -> None:
self._i18n: I18N = I18N()
self._printer: Printer = Printer()
@@ -54,6 +55,7 @@ class ToolUsage:
self.tools = tools
self.task = task
self.llm = llm
self.function_calling_llm = function_calling_llm
def use(self, tool_string: str):
calling = self._tool_calling(tool_string)
@@ -79,7 +81,9 @@ class ToolUsage:
try:
result = self._i18n.errors("task_repeated_usage").format(
tool=calling.tool_name,
tool_input=", ".join(calling.arguments.values()),
tool_input=", ".join(
[str(arg) for arg in calling.arguments.values()]
),
)
self._printer.print(content=f"\n\n{result}\n", color="yellow")
self._telemetry.tool_repeated_usage(
@@ -138,7 +142,9 @@ class ToolUsage:
self, calling: Union[ToolCalling, InstructorToolCalling]
) -> None:
if last_tool_usage := self.tools_handler.last_used_tool:
return calling == last_tool_usage
return (calling.tool_name == last_tool_usage.tool_name) and (
calling.arguments == last_tool_usage.arguments
)
def _select_tool(self, tool_name: str) -> BaseTool:
for tool in self.tools:
@@ -175,15 +181,18 @@ class ToolUsage:
tool_string = tool_string.replace("Action:", "Tool Name:")
tool_string = tool_string.replace("Action Input:", "Tool Arguments:")
if (isinstance(self.llm, ChatOpenAI)) and (
self.llm.openai_api_base == None
):
llm = self.function_calling_llm or self.llm
if (isinstance(llm, ChatOpenAI)) and (llm.openai_api_base == None):
print("CARALHOooooooooooo")
print(llm)
print("CARALHOooooooooooo")
client = instructor.patch(
self.llm.client._client,
llm.client._client,
mode=instructor.Mode.FUNCTIONS,
)
calling = client.chat.completions.create(
model=self.llm.model_name,
model=llm.model_name,
messages=[
{
"role": "system",
@@ -220,13 +229,13 @@ class ToolUsage:
""",
},
)
chain = prompt | self.llm | parser
chain = prompt | llm | parser
calling = chain.invoke({"tool_string": tool_string})
except Exception:
self._run_attempts += 1
if self._run_attempts > self._max_parsing_attempts:
self._telemetry.tool_usage_error(llm=self.llm)
self._telemetry.tool_usage_error(llm=llm)
return ToolUsageErrorException(self._i18n.errors("tool_usage_error"))
return self._tool_calling(tool_string)