mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: Update base_agent_tools and base_tool to fix type errors
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -82,12 +82,12 @@ class BaseAgentTool(BaseTool):
|
||||
available_agents = [agent.role for agent in self.agents]
|
||||
logger.debug(f"Available agents: {available_agents}")
|
||||
|
||||
agent = [ # type: ignore # Incompatible types in assignment (expression has type "list[BaseAgent]", variable has type "str | None")
|
||||
matching_agents = [
|
||||
available_agent
|
||||
for available_agent in self.agents
|
||||
if self.sanitize_agent_name(available_agent.role) == sanitized_name
|
||||
]
|
||||
logger.debug(f"Found {len(agent)} matching agents for role '{sanitized_name}'")
|
||||
logger.debug(f"Found {len(matching_agents)} matching agents for role '{sanitized_name}'")
|
||||
except (AttributeError, ValueError) as e:
|
||||
# Handle specific exceptions that might occur during role name processing
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
@@ -97,7 +97,7 @@ class BaseAgentTool(BaseTool):
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
if not agent:
|
||||
if not matching_agents:
|
||||
# No matching agent found after sanitization
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
@@ -106,19 +106,19 @@ class BaseAgentTool(BaseTool):
|
||||
error=f"No agent found with role '{sanitized_name}'"
|
||||
)
|
||||
|
||||
agent = agent[0]
|
||||
selected_agent = matching_agents[0]
|
||||
try:
|
||||
task_with_assigned_agent = Task(
|
||||
description=task,
|
||||
agent=agent,
|
||||
expected_output=agent.i18n.slice("manager_request"),
|
||||
i18n=agent.i18n,
|
||||
agent=selected_agent,
|
||||
expected_output=selected_agent.i18n.slice("manager_request"),
|
||||
i18n=selected_agent.i18n,
|
||||
)
|
||||
logger.debug(f"Created task for agent '{self.sanitize_agent_name(agent.role)}': {task}")
|
||||
return agent.execute_task(task_with_assigned_agent, context)
|
||||
logger.debug(f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}")
|
||||
return selected_agent.execute_task(task_with_assigned_agent, context)
|
||||
except Exception as e:
|
||||
# Handle task creation or execution errors
|
||||
return self.i18n.errors("agent_tool_execution_error").format(
|
||||
agent_role=self.sanitize_agent_name(agent.role),
|
||||
agent_role=self.sanitize_agent_name(selected_agent.role),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@@ -10,4 +10,160 @@ def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str,
|
||||
"""Helper function to create model fields with proper type hints."""
|
||||
return {name: (annotation, field) for name, (annotation, field) in fields.items()}
|
||||
|
||||
# Rest of base_tool.py content...
|
||||
class BaseTool(BaseModel, ABC):
|
||||
"""Base class for all tools."""
|
||||
|
||||
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||
pass
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
func: Optional[Callable] = None
|
||||
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool."""
|
||||
args_schema: Type[PydanticBaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
|
||||
"""The schema for the arguments that the tool accepts."""
|
||||
description_updated: bool = False
|
||||
"""Flag to check if the description has been updated."""
|
||||
cache_function: Callable = lambda _args=None, _result=None: True
|
||||
"""Function that will be used to determine if the tool should be cached."""
|
||||
result_as_answer: bool = False
|
||||
"""Flag to check if the tool should be the final agent answer."""
|
||||
|
||||
@validator("args_schema", always=True, pre=True)
|
||||
def _default_args_schema(
|
||||
cls, v: Type[PydanticBaseModel]
|
||||
) -> Type[PydanticBaseModel]:
|
||||
if not isinstance(v, cls._ArgsSchemaPlaceholder):
|
||||
return v
|
||||
|
||||
return type(
|
||||
f"{cls.__name__}Schema",
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in cls._run.__annotations__.items() if k != "return"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._generate_description()
|
||||
super().model_post_init(__context)
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
print(f"Using Tool: {self.name}")
|
||||
return self._run(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Here goes the actual implementation of the tool."""
|
||||
|
||||
def _set_args_schema(self) -> None:
|
||||
if self.args_schema is None:
|
||||
class_name = f"{self.__class__.__name__}Schema"
|
||||
self.args_schema = type(
|
||||
class_name,
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v
|
||||
for k, v in self._run.__annotations__.items()
|
||||
if k != "return"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _generate_description(self) -> None:
|
||||
args_schema = {
|
||||
name: {
|
||||
"description": field.description,
|
||||
"type": BaseTool._get_arg_annotations(field.annotation),
|
||||
}
|
||||
for name, field in self.args_schema.model_fields.items()
|
||||
}
|
||||
|
||||
self.description = f"Tool Name: {self.name}\nTool Arguments: {args_schema}\nTool Description: {self.description}"
|
||||
|
||||
@staticmethod
|
||||
def _get_arg_annotations(annotation: type[Any] | None) -> str:
|
||||
if annotation is None:
|
||||
return "None"
|
||||
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin is None:
|
||||
return (
|
||||
annotation.__name__
|
||||
if hasattr(annotation, "__name__")
|
||||
else str(annotation)
|
||||
)
|
||||
|
||||
if args:
|
||||
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
||||
return f"{origin.__name__}[{args_str}]"
|
||||
|
||||
return origin.__name__
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
"""Tool class that wraps a function."""
|
||||
|
||||
func: Callable
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if "func" not in kwargs:
|
||||
raise ValueError("Tool requires a 'func' argument")
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def tool(*args: Any) -> Any:
|
||||
"""Decorator to create a tool from a function."""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(f: Callable) -> Tool:
|
||||
if f.__doc__ is None:
|
||||
raise ValueError("Function must have a docstring")
|
||||
if f.__annotations__ is None:
|
||||
raise ValueError("Function must have type annotations")
|
||||
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = type(
|
||||
class_name,
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != "return"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
return Tool(
|
||||
name=tool_name,
|
||||
description=f.__doc__,
|
||||
func=f,
|
||||
args_schema=args_schema,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
|
||||
if len(args) == 1 and callable(args[0]):
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
return _make_with_name(args[0])
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
Reference in New Issue
Block a user