mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Refactor Flow class to improve code quality
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -17,6 +17,11 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
# Forward reference for type hints
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.agent import Agent
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
from crewai.flow.flow_visualizer import plot_flow
|
from crewai.flow.flow_visualizer import plot_flow
|
||||||
@@ -788,27 +793,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
|
|
||||||
final_output = self._method_outputs[-1] if self._method_outputs else None
|
final_output = self._method_outputs[-1] if self._method_outputs else None
|
||||||
|
|
||||||
# Check if any agent in the flow has tools_results with result_as_answer=True
|
# Check for tool results with result_as_answer=True in method-associated agents
|
||||||
for method_name, method in self._methods.items():
|
for method_name, method in self._methods.items():
|
||||||
if hasattr(method, "__agent"):
|
if hasattr(method, "__agent"):
|
||||||
agent = getattr(method, "__agent")
|
tool_result = self._check_tool_results(getattr(method, "__agent"))
|
||||||
if hasattr(agent, "tools_results") and agent.tools_results:
|
if tool_result:
|
||||||
for tool_result in agent.tools_results:
|
final_output = tool_result
|
||||||
if tool_result.get("result_as_answer", False):
|
break
|
||||||
final_output = tool_result["result"]
|
|
||||||
break
|
|
||||||
|
|
||||||
# Also check for any agents that might be stored as instance variables
|
# Also check for agents stored as instance variables
|
||||||
for attr_name in dir(self):
|
for attr_name in dir(self):
|
||||||
if attr_name.startswith('_'):
|
if attr_name.startswith('_'):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
attr = getattr(self, attr_name)
|
attr = getattr(self, attr_name)
|
||||||
if hasattr(attr, "tools_results") and attr.tools_results:
|
tool_result = self._check_tool_results(attr)
|
||||||
for tool_result in attr.tools_results:
|
if tool_result:
|
||||||
if tool_result.get("result_as_answer", False):
|
final_output = tool_result
|
||||||
final_output = tool_result["result"]
|
break
|
||||||
break
|
|
||||||
except (AttributeError, TypeError):
|
except (AttributeError, TypeError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -1094,8 +1096,28 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
elif level == "warning":
|
elif level == "warning":
|
||||||
logger.warning(message)
|
logger.warning(message)
|
||||||
|
|
||||||
|
def _check_tool_results(self, obj) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Check if an object has tool results with result_as_answer=True.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
obj : Any
|
||||||
|
The object to check for tool results
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Optional[str]
|
||||||
|
The tool result if found with result_as_answer=True, None otherwise
|
||||||
|
"""
|
||||||
|
if hasattr(obj, "tools_results") and obj.tools_results:
|
||||||
|
for tool_result in obj.tools_results:
|
||||||
|
if tool_result.get("result_as_answer", False):
|
||||||
|
return tool_result["result"]
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def with_agent(cls, agent):
|
def with_agent(cls, agent: 'Agent') -> Callable:
|
||||||
"""
|
"""
|
||||||
Decorator to associate an agent with a flow method.
|
Decorator to associate an agent with a flow method.
|
||||||
This allows tracking which agents are used in the flow.
|
This allows tracking which agents are used in the flow.
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
import pytest
|
|
||||||
from typing import Type
|
from typing import Type
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai import Agent, Crew, Flow, Task
|
from crewai import Agent, Crew, Flow, Task
|
||||||
from crewai.flow import listen, start
|
from crewai.flow import listen, start
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
class TestToolInput(BaseModel):
|
class TestToolInput(BaseModel):
|
||||||
query: str = Field(..., description='Query to process')
|
query: str = Field(..., description='Query to process')
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user