Fix types

This commit is contained in:
Gui Vieira
2024-02-08 18:34:04 -03:00
parent 44b6bcbcaa
commit 9a210afd80
6 changed files with 29 additions and 26 deletions

View File

@@ -1,6 +1,5 @@
from typing import Any from langchain_core.agents import AgentAction
from pydantic.v1 import BaseModel, Field
from pydantic import BaseModel, Field
from .cache_handler import CacheHandler from .cache_handler import CacheHandler
@@ -11,8 +10,5 @@ class CacheHit(BaseModel):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
# Making it Any instead of AgentAction to avoind action: AgentAction = Field(description="Action taken")
# pydantic v1 vs v2 incompatibility, langchain should
# soon be updated to pydantic v2
action: Any = Field(description="Action taken")
cache: CacheHandler = Field(description="Cache Handler for the tool") cache: CacheHandler = Field(description="Cache Handler for the tool")

View File

@@ -106,15 +106,16 @@ class CrewAgentExecutor(AgentExecutor):
**inputs, **inputs,
) )
if self._should_force_answer(): if self._should_force_answer():
if isinstance(output, AgentAction) or isinstance(output, AgentFinish): if isinstance(output, CacheHit):
output = output
elif isinstance(output, CacheHit):
output = output.action output = output.action
else: if isinstance(output, AgentAction):
raise ValueError( yield self._force_answer(output)
f"Unexpected output type from agent: {type(output)}" return
) if isinstance(output, list):
yield self._force_answer(output) yield from [self._force_answer(action) for action in output]
return
yield output
return return
except OutputParserException as e: except OutputParserException as e:

View File

@@ -73,7 +73,7 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
) )
if self.cache.read(action, tool_input): if self.cache.read(action, tool_input):
action = AgentAction(action, tool_input, text) agent_action = AgentAction(action, tool_input, text)
return CacheHit(action=action, cache=self.cache) return CacheHit(action=agent_action, cache=self.cache)
return super().parse(text) return super().parse(text)

View File

@@ -171,7 +171,7 @@ class Crew(BaseModel):
def _run_sequential_process(self) -> str: def _run_sequential_process(self) -> str:
"""Executes tasks sequentially and returns the final output.""" """Executes tasks sequentially and returns the final output."""
task_output = "" task_output: str = ""
for task in self.tasks: for task in self.tasks:
if task.agent is not None and task.agent.allow_delegation: if task.agent is not None and task.agent.allow_delegation:
agents_for_delegation = [ agents_for_delegation = [
@@ -185,6 +185,7 @@ class Crew(BaseModel):
output = task.execute(context=task_output) output = task.execute(context=task_output)
if not task.async_execution: if not task.async_execution:
assert output is not None
task_output = output task_output = output
role = task.agent.role if task.agent is not None else "None" role = task.agent.role if task.agent is not None else "None"
@@ -208,14 +209,17 @@ class Crew(BaseModel):
verbose=True, verbose=True,
) )
task_output = "" task_output: str = ""
for task in self.tasks: for task in self.tasks:
self._logger.log("debug", f"Working Agent: {manager.role}") self._logger.log("debug", f"Working Agent: {manager.role}")
self._logger.log("info", f"Starting Task: {task.description}") self._logger.log("info", f"Starting Task: {task.description}")
task_output = task.execute( output = task.execute(
agent=manager, context=task_output, tools=manager.tools agent=manager, context=task_output, tools=manager.tools
) )
if not task.async_execution:
assert output is not None
task_output = output
self._logger.log( self._logger.log(
"debug", f"[{manager.role}] Task output: {task_output}\n\n" "debug", f"[{manager.role}] Task output: {task_output}\n\n"

View File

@@ -18,7 +18,7 @@ class Task(BaseModel):
__hash__ = object.__hash__ # type: ignore __hash__ = object.__hash__ # type: ignore
i18n: I18N = I18N() i18n: I18N = I18N()
thread: threading.Thread = None thread: threading.Thread | None = None
description: str = Field(description="Description of the actual task.") description: str = Field(description="Description of the actual task.")
callback: Optional[Any] = Field( callback: Optional[Any] = Field(
description="Callback to be executed after the task is completed.", default=None description="Callback to be executed after the task is completed.", default=None
@@ -71,7 +71,7 @@ class Task(BaseModel):
agent: Agent | None = None, agent: Agent | None = None,
context: Optional[str] = None, context: Optional[str] = None,
tools: Optional[List[Any]] = None, tools: Optional[List[Any]] = None,
) -> str: ) -> str | None:
"""Execute the task. """Execute the task.
Returns: Returns:
@@ -85,12 +85,14 @@ class Task(BaseModel):
) )
if self.context: if self.context:
context = [] results = []
for task in self.context: for task in self.context:
if task.async_execution: if task.async_execution:
assert task.thread is not None
task.thread.join() task.thread.join()
context.append(task.output.result) if task.output is not None:
context = "\n".join(context) results.append(task.output.result)
context = "\n".join(results)
tools = tools or self.tools tools = tools or self.tools

View File

@@ -411,7 +411,7 @@ def test_async_task_execution():
with patch.object(threading.Thread, "start") as start: with patch.object(threading.Thread, "start") as start:
thread = threading.Thread(target=lambda: None, args=()).start() thread = threading.Thread(target=lambda: None, args=()).start()
start.return_value = thread start.return_value = thread
with patch.object(threading.Thread, "join", wraps=thread.join()) as join: with patch.object(threading.Thread, "join", wraps=thread.join()) as join: # type: ignore
list_ideas.output = TaskOutput( list_ideas.output = TaskOutput(
description="A 4 paragraph article about AI.", result="ok" description="A 4 paragraph article about AI.", result="ok"
) )