Compare commits

...

1 Commits

Author SHA1 Message Date
Gui Vieira
9a210afd80 Fix types 2024-02-08 18:34:04 -03:00
6 changed files with 29 additions and 26 deletions

View File

@@ -1,6 +1,5 @@
from typing import Any
from pydantic import BaseModel, Field
from langchain_core.agents import AgentAction
from pydantic.v1 import BaseModel, Field
from .cache_handler import CacheHandler
@@ -11,8 +10,5 @@ class CacheHit(BaseModel):
class Config:
arbitrary_types_allowed = True
# Making it Any instead of AgentAction to avoind
# pydantic v1 vs v2 incompatibility, langchain should
# soon be updated to pydantic v2
action: Any = Field(description="Action taken")
action: AgentAction = Field(description="Action taken")
cache: CacheHandler = Field(description="Cache Handler for the tool")

View File

@@ -106,15 +106,16 @@ class CrewAgentExecutor(AgentExecutor):
**inputs,
)
if self._should_force_answer():
if isinstance(output, AgentAction) or isinstance(output, AgentFinish):
output = output
elif isinstance(output, CacheHit):
if isinstance(output, CacheHit):
output = output.action
else:
raise ValueError(
f"Unexpected output type from agent: {type(output)}"
)
yield self._force_answer(output)
if isinstance(output, AgentAction):
yield self._force_answer(output)
return
if isinstance(output, list):
yield from [self._force_answer(action) for action in output]
return
yield output
return
except OutputParserException as e:

View File

@@ -73,7 +73,7 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
)
if self.cache.read(action, tool_input):
action = AgentAction(action, tool_input, text)
return CacheHit(action=action, cache=self.cache)
agent_action = AgentAction(action, tool_input, text)
return CacheHit(action=agent_action, cache=self.cache)
return super().parse(text)

View File

@@ -171,7 +171,7 @@ class Crew(BaseModel):
def _run_sequential_process(self) -> str:
"""Executes tasks sequentially and returns the final output."""
task_output = ""
task_output: str = ""
for task in self.tasks:
if task.agent is not None and task.agent.allow_delegation:
agents_for_delegation = [
@@ -185,6 +185,7 @@ class Crew(BaseModel):
output = task.execute(context=task_output)
if not task.async_execution:
assert output is not None
task_output = output
role = task.agent.role if task.agent is not None else "None"
@@ -208,14 +209,17 @@ class Crew(BaseModel):
verbose=True,
)
task_output = ""
task_output: str = ""
for task in self.tasks:
self._logger.log("debug", f"Working Agent: {manager.role}")
self._logger.log("info", f"Starting Task: {task.description}")
task_output = task.execute(
output = task.execute(
agent=manager, context=task_output, tools=manager.tools
)
if not task.async_execution:
assert output is not None
task_output = output
self._logger.log(
"debug", f"[{manager.role}] Task output: {task_output}\n\n"

View File

@@ -18,7 +18,7 @@ class Task(BaseModel):
__hash__ = object.__hash__ # type: ignore
i18n: I18N = I18N()
thread: threading.Thread = None
thread: threading.Thread | None = None
description: str = Field(description="Description of the actual task.")
callback: Optional[Any] = Field(
description="Callback to be executed after the task is completed.", default=None
@@ -71,7 +71,7 @@ class Task(BaseModel):
agent: Agent | None = None,
context: Optional[str] = None,
tools: Optional[List[Any]] = None,
) -> str:
) -> str | None:
"""Execute the task.
Returns:
@@ -85,12 +85,14 @@ class Task(BaseModel):
)
if self.context:
context = []
results = []
for task in self.context:
if task.async_execution:
assert task.thread is not None
task.thread.join()
context.append(task.output.result)
context = "\n".join(context)
if task.output is not None:
results.append(task.output.result)
context = "\n".join(results)
tools = tools or self.tools

View File

@@ -411,7 +411,7 @@ def test_async_task_execution():
with patch.object(threading.Thread, "start") as start:
thread = threading.Thread(target=lambda: None, args=()).start()
start.return_value = thread
with patch.object(threading.Thread, "join", wraps=thread.join()) as join:
with patch.object(threading.Thread, "join", wraps=thread.join()) as join: # type: ignore
list_ideas.output = TaskOutput(
description="A 4 paragraph article about AI.", result="ok"
)