mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
fix: Update base_agent_tools and base_tool to fix type errors
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -82,12 +82,12 @@ class BaseAgentTool(BaseTool):
|
|||||||
available_agents = [agent.role for agent in self.agents]
|
available_agents = [agent.role for agent in self.agents]
|
||||||
logger.debug(f"Available agents: {available_agents}")
|
logger.debug(f"Available agents: {available_agents}")
|
||||||
|
|
||||||
agent = [ # type: ignore # Incompatible types in assignment (expression has type "list[BaseAgent]", variable has type "str | None")
|
matching_agents = [
|
||||||
available_agent
|
available_agent
|
||||||
for available_agent in self.agents
|
for available_agent in self.agents
|
||||||
if self.sanitize_agent_name(available_agent.role) == sanitized_name
|
if self.sanitize_agent_name(available_agent.role) == sanitized_name
|
||||||
]
|
]
|
||||||
logger.debug(f"Found {len(agent)} matching agents for role '{sanitized_name}'")
|
logger.debug(f"Found {len(matching_agents)} matching agents for role '{sanitized_name}'")
|
||||||
except (AttributeError, ValueError) as e:
|
except (AttributeError, ValueError) as e:
|
||||||
# Handle specific exceptions that might occur during role name processing
|
# Handle specific exceptions that might occur during role name processing
|
||||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||||
@@ -97,7 +97,7 @@ class BaseAgentTool(BaseTool):
|
|||||||
error=str(e)
|
error=str(e)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not agent:
|
if not matching_agents:
|
||||||
# No matching agent found after sanitization
|
# No matching agent found after sanitization
|
||||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||||
coworkers="\n".join(
|
coworkers="\n".join(
|
||||||
@@ -106,19 +106,19 @@ class BaseAgentTool(BaseTool):
|
|||||||
error=f"No agent found with role '{sanitized_name}'"
|
error=f"No agent found with role '{sanitized_name}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = agent[0]
|
selected_agent = matching_agents[0]
|
||||||
try:
|
try:
|
||||||
task_with_assigned_agent = Task(
|
task_with_assigned_agent = Task(
|
||||||
description=task,
|
description=task,
|
||||||
agent=agent,
|
agent=selected_agent,
|
||||||
expected_output=agent.i18n.slice("manager_request"),
|
expected_output=selected_agent.i18n.slice("manager_request"),
|
||||||
i18n=agent.i18n,
|
i18n=selected_agent.i18n,
|
||||||
)
|
)
|
||||||
logger.debug(f"Created task for agent '{self.sanitize_agent_name(agent.role)}': {task}")
|
logger.debug(f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}")
|
||||||
return agent.execute_task(task_with_assigned_agent, context)
|
return selected_agent.execute_task(task_with_assigned_agent, context)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle task creation or execution errors
|
# Handle task creation or execution errors
|
||||||
return self.i18n.errors("agent_tool_execution_error").format(
|
return self.i18n.errors("agent_tool_execution_error").format(
|
||||||
agent_role=self.sanitize_agent_name(agent.role),
|
agent_role=self.sanitize_agent_name(selected_agent.role),
|
||||||
error=str(e)
|
error=str(e)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,4 +10,160 @@ def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str,
|
|||||||
"""Helper function to create model fields with proper type hints."""
|
"""Helper function to create model fields with proper type hints."""
|
||||||
return {name: (annotation, field) for name, (annotation, field) in fields.items()}
|
return {name: (annotation, field) for name, (annotation, field) in fields.items()}
|
||||||
|
|
||||||
# Rest of base_tool.py content...
|
class BaseTool(BaseModel, ABC):
|
||||||
|
"""Base class for all tools."""
|
||||||
|
|
||||||
|
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
func: Optional[Callable] = None
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""The unique name of the tool that clearly communicates its purpose."""
|
||||||
|
description: str
|
||||||
|
"""Used to tell the model how/when/why to use the tool."""
|
||||||
|
args_schema: Type[PydanticBaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
|
||||||
|
"""The schema for the arguments that the tool accepts."""
|
||||||
|
description_updated: bool = False
|
||||||
|
"""Flag to check if the description has been updated."""
|
||||||
|
cache_function: Callable = lambda _args=None, _result=None: True
|
||||||
|
"""Function that will be used to determine if the tool should be cached."""
|
||||||
|
result_as_answer: bool = False
|
||||||
|
"""Flag to check if the tool should be the final agent answer."""
|
||||||
|
|
||||||
|
@validator("args_schema", always=True, pre=True)
|
||||||
|
def _default_args_schema(
|
||||||
|
cls, v: Type[PydanticBaseModel]
|
||||||
|
) -> Type[PydanticBaseModel]:
|
||||||
|
if not isinstance(v, cls._ArgsSchemaPlaceholder):
|
||||||
|
return v
|
||||||
|
|
||||||
|
return type(
|
||||||
|
f"{cls.__name__}Schema",
|
||||||
|
(PydanticBaseModel,),
|
||||||
|
{
|
||||||
|
"__annotations__": {
|
||||||
|
k: v for k, v in cls._run.__annotations__.items() if k != "return"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def model_post_init(self, __context: Any) -> None:
|
||||||
|
self._generate_description()
|
||||||
|
super().model_post_init(__context)
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
print(f"Using Tool: {self.name}")
|
||||||
|
return self._run(*args, **kwargs)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""Here goes the actual implementation of the tool."""
|
||||||
|
|
||||||
|
def _set_args_schema(self) -> None:
|
||||||
|
if self.args_schema is None:
|
||||||
|
class_name = f"{self.__class__.__name__}Schema"
|
||||||
|
self.args_schema = type(
|
||||||
|
class_name,
|
||||||
|
(PydanticBaseModel,),
|
||||||
|
{
|
||||||
|
"__annotations__": {
|
||||||
|
k: v
|
||||||
|
for k, v in self._run.__annotations__.items()
|
||||||
|
if k != "return"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_description(self) -> None:
|
||||||
|
args_schema = {
|
||||||
|
name: {
|
||||||
|
"description": field.description,
|
||||||
|
"type": BaseTool._get_arg_annotations(field.annotation),
|
||||||
|
}
|
||||||
|
for name, field in self.args_schema.model_fields.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
self.description = f"Tool Name: {self.name}\nTool Arguments: {args_schema}\nTool Description: {self.description}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_arg_annotations(annotation: type[Any] | None) -> str:
|
||||||
|
if annotation is None:
|
||||||
|
return "None"
|
||||||
|
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
args = get_args(annotation)
|
||||||
|
|
||||||
|
if origin is None:
|
||||||
|
return (
|
||||||
|
annotation.__name__
|
||||||
|
if hasattr(annotation, "__name__")
|
||||||
|
else str(annotation)
|
||||||
|
)
|
||||||
|
|
||||||
|
if args:
|
||||||
|
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
||||||
|
return f"{origin.__name__}[{args_str}]"
|
||||||
|
|
||||||
|
return origin.__name__
|
||||||
|
|
||||||
|
|
||||||
|
class Tool(BaseTool):
|
||||||
|
"""Tool class that wraps a function."""
|
||||||
|
|
||||||
|
func: Callable
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
if "func" not in kwargs:
|
||||||
|
raise ValueError("Tool requires a 'func' argument")
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
return self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def tool(*args: Any) -> Any:
|
||||||
|
"""Decorator to create a tool from a function."""
|
||||||
|
|
||||||
|
def _make_with_name(tool_name: str) -> Callable:
|
||||||
|
def _make_tool(f: Callable) -> Tool:
|
||||||
|
if f.__doc__ is None:
|
||||||
|
raise ValueError("Function must have a docstring")
|
||||||
|
if f.__annotations__ is None:
|
||||||
|
raise ValueError("Function must have type annotations")
|
||||||
|
|
||||||
|
class_name = "".join(tool_name.split()).title()
|
||||||
|
args_schema = type(
|
||||||
|
class_name,
|
||||||
|
(PydanticBaseModel,),
|
||||||
|
{
|
||||||
|
"__annotations__": {
|
||||||
|
k: v for k, v in f.__annotations__.items() if k != "return"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return Tool(
|
||||||
|
name=tool_name,
|
||||||
|
description=f.__doc__,
|
||||||
|
func=f,
|
||||||
|
args_schema=args_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _make_tool
|
||||||
|
|
||||||
|
if len(args) == 1 and callable(args[0]):
|
||||||
|
return _make_with_name(args[0].__name__)(args[0])
|
||||||
|
if len(args) == 1 and isinstance(args[0], str):
|
||||||
|
return _make_with_name(args[0])
|
||||||
|
raise ValueError("Invalid arguments")
|
||||||
|
|||||||
Reference in New Issue
Block a user