diff --git a/src/crewai/agents/cache/cache_hit.py b/src/crewai/agents/cache/cache_hit.py index 07a711c29..699c6bb14 100644 --- a/src/crewai/agents/cache/cache_hit.py +++ b/src/crewai/agents/cache/cache_hit.py @@ -1,6 +1,5 @@ -from typing import Any - -from pydantic import BaseModel, Field +from langchain_core.agents import AgentAction +from pydantic.v1 import BaseModel, Field from .cache_handler import CacheHandler @@ -11,8 +10,5 @@ class CacheHit(BaseModel): class Config: arbitrary_types_allowed = True - # Making it Any instead of AgentAction to avoind - # pydantic v1 vs v2 incompatibility, langchain should - # soon be updated to pydantic v2 - action: Any = Field(description="Action taken") + action: AgentAction = Field(description="Action taken") cache: CacheHandler = Field(description="Cache Handler for the tool") diff --git a/src/crewai/agents/executor.py b/src/crewai/agents/executor.py index bacf0b39d..65fd333ff 100644 --- a/src/crewai/agents/executor.py +++ b/src/crewai/agents/executor.py @@ -106,15 +106,16 @@ class CrewAgentExecutor(AgentExecutor): **inputs, ) if self._should_force_answer(): - if isinstance(output, AgentAction) or isinstance(output, AgentFinish): - output = output - elif isinstance(output, CacheHit): + if isinstance(output, CacheHit): output = output.action - else: - raise ValueError( - f"Unexpected output type from agent: {type(output)}" - ) - yield self._force_answer(output) + if isinstance(output, AgentAction): + yield self._force_answer(output) + return + if isinstance(output, list): + yield from [self._force_answer(action) for action in output] + return + + yield output return except OutputParserException as e: diff --git a/src/crewai/agents/output_parser.py b/src/crewai/agents/output_parser.py index 9edeb12b0..740a83451 100644 --- a/src/crewai/agents/output_parser.py +++ b/src/crewai/agents/output_parser.py @@ -73,7 +73,7 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser): ) if self.cache.read(action, tool_input): - action = AgentAction(action, tool_input, text) - return CacheHit(action=action, cache=self.cache) + agent_action = AgentAction(action, tool_input, text) + return CacheHit(action=agent_action, cache=self.cache) return super().parse(text) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 16176f7da..b29062316 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -171,7 +171,7 @@ class Crew(BaseModel): def _run_sequential_process(self) -> str: """Executes tasks sequentially and returns the final output.""" - task_output = "" + task_output: str = "" for task in self.tasks: if task.agent is not None and task.agent.allow_delegation: agents_for_delegation = [ @@ -185,6 +185,7 @@ class Crew(BaseModel): output = task.execute(context=task_output) if not task.async_execution: + assert output is not None task_output = output role = task.agent.role if task.agent is not None else "None" @@ -208,14 +209,17 @@ class Crew(BaseModel): verbose=True, ) - task_output = "" + task_output: str = "" for task in self.tasks: self._logger.log("debug", f"Working Agent: {manager.role}") self._logger.log("info", f"Starting Task: {task.description}") - task_output = task.execute( + output = task.execute( agent=manager, context=task_output, tools=manager.tools ) + if not task.async_execution: + assert output is not None + task_output = output self._logger.log( "debug", f"[{manager.role}] Task output: {task_output}\n\n" diff --git a/src/crewai/task.py b/src/crewai/task.py index 18d90aabf..fc590f627 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -18,7 +18,7 @@ class Task(BaseModel): __hash__ = object.__hash__ # type: ignore i18n: I18N = I18N() - thread: threading.Thread = None + thread: threading.Thread | None = None description: str = Field(description="Description of the actual task.") callback: Optional[Any] = Field( description="Callback to be executed after the task is completed.", default=None @@ -71,7 +71,7 @@ class Task(BaseModel): agent: Agent | None = None, context: Optional[str] = None, tools: Optional[List[Any]] = None, - ) -> str: + ) -> str | None: """Execute the task. Returns: @@ -85,12 +85,14 @@ class Task(BaseModel): ) if self.context: - context = [] + results = [] for task in self.context: if task.async_execution: + assert task.thread is not None task.thread.join() - context.append(task.output.result) - context = "\n".join(context) + if task.output is not None: + results.append(task.output.result) + context = "\n".join(results) tools = tools or self.tools diff --git a/tests/crew_test.py b/tests/crew_test.py index 57bf715fc..5d123d3da 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -411,7 +411,7 @@ def test_async_task_execution(): with patch.object(threading.Thread, "start") as start: thread = threading.Thread(target=lambda: None, args=()).start() start.return_value = thread - with patch.object(threading.Thread, "join", wraps=thread.join()) as join: + with patch.object(threading.Thread, "join", wraps=thread.join()) as join: # type: ignore list_ideas.output = TaskOutput( description="A 4 paragraph article about AI.", result="ok" )