diff --git a/src/crewai/tools/agent_tools/base_agent_tools.py b/src/crewai/tools/agent_tools/base_agent_tools.py index b00fbb7b5..ad7db09bd 100644 --- a/src/crewai/tools/agent_tools/base_agent_tools.py +++ b/src/crewai/tools/agent_tools/base_agent_tools.py @@ -82,12 +82,12 @@ class BaseAgentTool(BaseTool): available_agents = [agent.role for agent in self.agents] logger.debug(f"Available agents: {available_agents}") - agent = [ # type: ignore # Incompatible types in assignment (expression has type "list[BaseAgent]", variable has type "str | None") + matching_agents = [ available_agent for available_agent in self.agents if self.sanitize_agent_name(available_agent.role) == sanitized_name ] - logger.debug(f"Found {len(agent)} matching agents for role '{sanitized_name}'") + logger.debug(f"Found {len(matching_agents)} matching agents for role '{sanitized_name}'") except (AttributeError, ValueError) as e: # Handle specific exceptions that might occur during role name processing return self.i18n.errors("agent_tool_unexisting_coworker").format( @@ -97,7 +97,7 @@ class BaseAgentTool(BaseTool): error=str(e) ) - if not agent: + if not matching_agents: # No matching agent found after sanitization return self.i18n.errors("agent_tool_unexisting_coworker").format( coworkers="\n".join( @@ -106,19 +106,19 @@ class BaseAgentTool(BaseTool): error=f"No agent found with role '{sanitized_name}'" ) - agent = agent[0] + selected_agent = matching_agents[0] try: task_with_assigned_agent = Task( description=task, - agent=agent, - expected_output=agent.i18n.slice("manager_request"), - i18n=agent.i18n, + agent=selected_agent, + expected_output=selected_agent.i18n.slice("manager_request"), + i18n=selected_agent.i18n, ) - logger.debug(f"Created task for agent '{self.sanitize_agent_name(agent.role)}': {task}") - return agent.execute_task(task_with_assigned_agent, context) + logger.debug(f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}") + return selected_agent.execute_task(task_with_assigned_agent, context) except Exception as e: # Handle task creation or execution errors return self.i18n.errors("agent_tool_execution_error").format( - agent_role=self.sanitize_agent_name(agent.role), + agent_role=self.sanitize_agent_name(selected_agent.role), error=str(e) ) diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py index 24cf64dde..0cec5b498 100644 --- a/src/crewai/tools/base_tool.py +++ b/src/crewai/tools/base_tool.py @@ -10,4 +10,160 @@ def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str, """Helper function to create model fields with proper type hints.""" return {name: (annotation, field) for name, (annotation, field) in fields.items()} -# Rest of base_tool.py content... +class BaseTool(BaseModel, ABC): + """Base class for all tools.""" + + class _ArgsSchemaPlaceholder(PydanticBaseModel): + pass + + model_config = ConfigDict(arbitrary_types_allowed=True) + func: Optional[Callable] = None + + name: str + """The unique name of the tool that clearly communicates its purpose.""" + description: str + """Used to tell the model how/when/why to use the tool.""" + args_schema: Type[PydanticBaseModel] = Field(default_factory=_ArgsSchemaPlaceholder) + """The schema for the arguments that the tool accepts.""" + description_updated: bool = False + """Flag to check if the description has been updated.""" + cache_function: Callable = lambda _args=None, _result=None: True + """Function that will be used to determine if the tool should be cached.""" + result_as_answer: bool = False + """Flag to check if the tool should be the final agent answer.""" + + @validator("args_schema", always=True, pre=True) + def _default_args_schema( + cls, v: Type[PydanticBaseModel] + ) -> Type[PydanticBaseModel]: + if not isinstance(v, cls._ArgsSchemaPlaceholder): + return v + + return type( + f"{cls.__name__}Schema", + (PydanticBaseModel,), + { + "__annotations__": { + k: v for k, v in cls._run.__annotations__.items() if k != "return" + }, + }, + ) + + def model_post_init(self, __context: Any) -> None: + self._generate_description() + super().model_post_init(__context) + + def run( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + print(f"Using Tool: {self.name}") + return self._run(*args, **kwargs) + + @abstractmethod + def _run( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Here goes the actual implementation of the tool.""" + + def _set_args_schema(self) -> None: + if self.args_schema is None: + class_name = f"{self.__class__.__name__}Schema" + self.args_schema = type( + class_name, + (PydanticBaseModel,), + { + "__annotations__": { + k: v + for k, v in self._run.__annotations__.items() + if k != "return" + }, + }, + ) + + def _generate_description(self) -> None: + args_schema = { + name: { + "description": field.description, + "type": BaseTool._get_arg_annotations(field.annotation), + } + for name, field in self.args_schema.model_fields.items() + } + + self.description = f"Tool Name: {self.name}\nTool Arguments: {args_schema}\nTool Description: {self.description}" + + @staticmethod + def _get_arg_annotations(annotation: type[Any] | None) -> str: + if annotation is None: + return "None" + + origin = get_origin(annotation) + args = get_args(annotation) + + if origin is None: + return ( + annotation.__name__ + if hasattr(annotation, "__name__") + else str(annotation) + ) + + if args: + args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args) + return f"{origin.__name__}[{args_str}]" + + return origin.__name__ + + +class Tool(BaseTool): + """Tool class that wraps a function.""" + + func: Callable + model_config = ConfigDict(arbitrary_types_allowed=True) + + def __init__(self, **kwargs): + if "func" not in kwargs: + raise ValueError("Tool requires a 'func' argument") + super().__init__(**kwargs) + + def _run(self, *args: Any, **kwargs: Any) -> Any: + return self.func(*args, **kwargs) + + +def tool(*args: Any) -> Any: + """Decorator to create a tool from a function.""" + + def _make_with_name(tool_name: str) -> Callable: + def _make_tool(f: Callable) -> Tool: + if f.__doc__ is None: + raise ValueError("Function must have a docstring") + if f.__annotations__ is None: + raise ValueError("Function must have type annotations") + + class_name = "".join(tool_name.split()).title() + args_schema = type( + class_name, + (PydanticBaseModel,), + { + "__annotations__": { + k: v for k, v in f.__annotations__.items() if k != "return" + }, + }, + ) + + return Tool( + name=tool_name, + description=f.__doc__, + func=f, + args_schema=args_schema, + ) + + return _make_tool + + if len(args) == 1 and callable(args[0]): + return _make_with_name(args[0].__name__)(args[0]) + if len(args) == 1 and isinstance(args[0], str): + return _make_with_name(args[0]) + raise ValueError("Invalid arguments")