Feature/kickoff for each sync (#680)

* Sync with deep copy working now

* async working!!

* Clean up code for review

* Fix naming

---------

Co-authored-by: João Moura <joaomdmoura@gmail.com>
This commit is contained in:
Brandon Hancock (bhancock_ai)
2024-06-11 11:51:39 -04:00
committed by GitHub
parent 2a0e21ca76
commit 946c56494e
3 changed files with 190 additions and 39 deletions

View File

@@ -1,3 +1,4 @@
from copy import deepcopy
import os
import re
import threading
@@ -163,13 +164,16 @@ class Task(BaseModel):
)
if self.context:
context = [] # type: ignore # Incompatible types in assignment (expression has type "list[Never]", variable has type "str | None")
# type: ignore # Incompatible types in assignment (expression has type "list[Never]", variable has type "str | None")
context = []
for task in self.context:
if task.async_execution:
task.thread.join() # type: ignore # Item "None" of "Thread | None" has no attribute "join"
if task and task.output:
context.append(task.output.raw_output) # type: ignore # Item "str" of "str | None" has no attribute "append"
context = "\n".join(context) # type: ignore # Argument 1 to "join" of "str" has incompatible type "str | None"; expected "Iterable[str]"
# type: ignore # Item "str" of "str | None" has no attribute "append"
context.append(task.output.raw_output)
# type: ignore # Argument 1 to "join" of "str" has incompatible type "str | None"; expected "Iterable[str]"
context = "\n".join(context)
self.prompt_context = context
tools = tools or self.tools
@@ -232,7 +236,8 @@ class Task(BaseModel):
if inputs:
self.description = self._original_description.format(**inputs)
self.expected_output = self._original_expected_output.format(**inputs)
self.expected_output = self._original_expected_output.format(
**inputs)
def increment_tools_errors(self) -> None:
"""Increment the tools errors counter."""
@@ -242,6 +247,25 @@ class Task(BaseModel):
"""Increment the delegations counter."""
self.delegations += 1
def copy(self):
"""Create a deep copy of the Task."""
exclude = {
"id",
"agent",
"context",
"tools",
}
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
cloned_context = [task.copy() for task in self.context] if self.context else None
cloned_agent = self.agent.copy() if self.agent else None
cloned_tools = deepcopy(self.tools) if self.tools else None
copied_task = Task(**copied_data, context=cloned_context, agent=cloned_agent, tools=cloned_tools)
return copied_task
def _export_output(self, result: str) -> Any:
exported_result = result
instructions = "I'm gonna convert this raw text into valid JSON."
@@ -251,27 +275,35 @@ class Task(BaseModel):
# try to convert task_output directly to pydantic/json
try:
exported_result = model.model_validate_json(result) # type: ignore # Item "None" of "type[BaseModel] | None" has no attribute "model_validate_json"
# type: ignore # Item "None" of "type[BaseModel] | None" has no attribute "model_validate_json"
exported_result = model.model_validate_json(result)
if self.output_json:
return exported_result.model_dump() # type: ignore # "str" has no attribute "model_dump"
# type: ignore # "str" has no attribute "model_dump"
return exported_result.model_dump()
return exported_result
except Exception:
# sometimes the response contains valid JSON in the middle of text
match = re.search(r"({.*})", result, re.DOTALL)
if match:
try:
exported_result = model.model_validate_json(match.group(0)) # type: ignore # Item "None" of "type[BaseModel] | None" has no attribute "model_validate_json"
# type: ignore # Item "None" of "type[BaseModel] | None" has no attribute "model_validate_json"
exported_result = model.model_validate_json(
match.group(0))
if self.output_json:
return exported_result.model_dump() # type: ignore # "str" has no attribute "model_dump"
# type: ignore # "str" has no attribute "model_dump"
return exported_result.model_dump()
return exported_result
except Exception:
pass
llm = self.agent.function_calling_llm or self.agent.llm # type: ignore # Item "None" of "Agent | None" has no attribute "function_calling_llm"
# type: ignore # Item "None" of "Agent | None" has no attribute "function_calling_llm"
llm = self.agent.function_calling_llm or self.agent.llm
if not self._is_gpt(llm):
model_schema = PydanticSchemaParser(model=model).get_schema() # type: ignore # Argument "model" to "PydanticSchemaParser" has incompatible type "type[BaseModel] | None"; expected "type[BaseModel]"
instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"
# type: ignore # Argument "model" to "PydanticSchemaParser" has incompatible type "type[BaseModel] | None"; expected "type[BaseModel]"
model_schema = PydanticSchemaParser(model=model).get_schema()
instructions = f"{
instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"
converter = Converter(
llm=llm, text=result, model=model, instructions=instructions
@@ -284,14 +316,16 @@ class Task(BaseModel):
if isinstance(exported_result, ConverterError):
Printer().print(
content=f"{exported_result.message} Using raw output instead.",
content=f"{
exported_result.message} Using raw output instead.",
color="red",
)
exported_result = result
if self.output_file:
content = (
exported_result if not self.output_pydantic else exported_result.json() # type: ignore # "str" has no attribute "json"
# type: ignore # "str" has no attribute "json"
exported_result if not self.output_pydantic else exported_result.json()
)
self._save_file(content)
@@ -301,12 +335,14 @@ class Task(BaseModel):
return isinstance(llm, ChatOpenAI) and llm.openai_api_base is None
def _save_file(self, result: Any) -> None:
directory = os.path.dirname(self.output_file) # type: ignore # Value of type variable "AnyOrLiteralStr" of "dirname" cannot be "str | None"
# type: ignore # Value of type variable "AnyOrLiteralStr" of "dirname" cannot be "str | None"
directory = os.path.dirname(self.output_file)
if directory and not os.path.exists(directory):
os.makedirs(directory)
with open(self.output_file, "w", encoding="utf-8") as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]"
# type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]"
with open(self.output_file, "w", encoding='utf-8') as file:
file.write(result)
return None