Adding new LLM class

This commit is contained in:
João Moura
2024-09-23 03:59:05 -03:00
parent 59e51f18fd
commit a19a4a5556
9 changed files with 124 additions and 93 deletions

View File

@@ -72,7 +72,8 @@ class ToolUsage:
# Set the maximum parsing attempts for bigger models
if (
self._is_gpt(self.function_calling_llm)
self.function_calling_llm
and self._is_gpt(self.function_calling_llm)
and self.function_calling_llm in OPENAI_BIGGER_MODELS
):
self._max_parsing_attempts = 2
@@ -85,6 +86,7 @@ class ToolUsage:
def use(
self, calling: Union[ToolCalling, InstructorToolCalling], tool_string: str
) -> str:
print("calling", calling)
if isinstance(calling, ToolUsageErrorException):
error = calling.message
if self.agent.verbose:
@@ -299,9 +301,9 @@ class ToolUsage:
def _is_gpt(self, llm) -> bool:
return (
"gpt" in str(llm).lower()
or "o1-preview" in str(llm).lower()
or "o1-mini" in str(llm).lower()
"gpt" in str(llm.model).lower()
or "o1-preview" in str(llm.model).lower()
or "o1-mini" in str(llm.model).lower()
)
def _tool_calling(
@@ -309,11 +311,16 @@ class ToolUsage:
) -> Union[ToolCalling, InstructorToolCalling]:
try:
if self.function_calling_llm:
print("self.function_calling_llm")
model = (
InstructorToolCalling
if self._is_gpt(self.function_calling_llm)
else ToolCalling
)
print("model", model)
print(
"self.function_calling_llm.model", self.function_calling_llm.model
)
converter = Converter(
text=f"Only tools available:\n###\n{self._render()}\n\nReturn a valid schema for the tool, the tool name must be exactly equal one of the options, use this text to inform the valid output schema:\n\n### TEXT \n{tool_string}",
llm=self.function_calling_llm,
@@ -329,7 +336,15 @@ class ToolUsage:
),
max_attempts=1,
)
calling = converter.to_pydantic()
print("converter", converter)
tool_object = converter.to_pydantic()
print("tool_object", tool_object)
calling = ToolCalling(
tool_name=tool_object["tool_name"],
arguments=tool_object["arguments"],
log=tool_string, # type: ignore
)
print("calling", calling)
if isinstance(calling, ConverterError):
raise calling