doc strings

This commit is contained in:
Brandon Hancock
2024-11-27 11:04:24 -05:00
parent 94176552f3
commit e748615c9c
2 changed files with 64 additions and 12 deletions

View File

@@ -138,7 +138,14 @@ class BaseAgent(ABC, BaseModel):
@field_validator("tools")
@classmethod
def validate_tools(cls, tools):
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
or an object with 'name', 'func', and 'description' attributes. If the
tool meets these criteria, it is processed and added to the list of
tools. Otherwise, a ValueError is raised.
"""
processed_tools = []
for tool in tools:
if isinstance(tool, BaseTool):

View File

@@ -76,17 +76,47 @@ class BaseTool(BaseModel, ABC):
)
@classmethod
def from_langchain(cls, tool: CrewStructuredTool) -> "BaseTool":
if cls == Tool:
if tool.func is None:
raise ValueError("CrewStructuredTool must have a callable 'func'")
return Tool(
name=tool.name,
description=tool.description,
args_schema=tool.args_schema,
func=tool.func,
)
raise NotImplementedError(f"from_langchain not implemented for {cls.__name__}")
def from_langchain(cls, tool: Any) -> "BaseTool":
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
Tool instance. It ensures that the provided tool has a callable 'func'
attribute and infers the argument schema if not explicitly provided.
"""
if not hasattr(tool, "func") or not callable(tool.func):
raise ValueError("The provided tool must have a callable 'func' attribute.")
args_schema = getattr(tool, "args_schema", None)
if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
return cls(
name=getattr(tool, "name", "Unnamed Tool"),
description=getattr(tool, "description", ""),
func=tool.func,
args_schema=args_schema,
)
def _set_args_schema(self):
if self.args_schema is None:
@@ -146,6 +176,21 @@ class Tool(BaseTool):
@classmethod
def from_langchain(cls, tool: Any) -> "Tool":
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
Tool instance. It ensures that the provided tool has a callable 'func'
attribute and infers the argument schema if not explicitly provided.
Args:
tool (Any): The CrewStructuredTool object to be converted.
Returns:
Tool: A new Tool instance created from the provided CrewStructuredTool.
Raises:
ValueError: If the provided tool does not have a callable 'func' attribute.
"""
if not hasattr(tool, "func") or not callable(tool.func):
raise ValueError("The provided tool must have a callable 'func' attribute.")