feat: add initial changes from langchain

This commit is contained in:
Eduardo Chiarotti
2024-11-22 15:49:40 -03:00
parent 93c0467bba
commit 1edff675eb
4 changed files with 264 additions and 14 deletions

View File

@@ -277,8 +277,8 @@ class Agent(BaseAgent):
if self.crew and self.crew.knowledge: if self.crew and self.crew.knowledge:
knowledge_snippets = self.crew.knowledge.query([task.prompt()]) knowledge_snippets = self.crew.knowledge.query([task.prompt()])
valid_snippets = [ valid_snippets = [
result["context"] result["context"]
for result in knowledge_snippets for result in knowledge_snippets
if result and result.get("context") if result and result.get("context")
] ]
if valid_snippets: if valid_snippets:
@@ -399,7 +399,7 @@ class Agent(BaseAgent):
for tool in tools: for tool in tools:
if isinstance(tool, CrewAITool): if isinstance(tool, CrewAITool):
tools_list.append(tool.to_langchain()) tools_list.append(tool.to_structured_tool())
else: else:
tools_list.append(tool) tools_list.append(tool)
except ModuleNotFoundError: except ModuleNotFoundError:

View File

@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Type, get_args, get_origin from typing import Any, Callable, Type, get_args, get_origin
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, ConfigDict, Field, validator from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic import BaseModel as PydanticBaseModel from pydantic import BaseModel as PydanticBaseModel
from crewai.tools.structured_tool import CrewStructuredTool
class BaseTool(BaseModel, ABC): class BaseTool(BaseModel, ABC):
class _ArgsSchemaPlaceholder(PydanticBaseModel): class _ArgsSchemaPlaceholder(PydanticBaseModel):
@@ -63,9 +64,10 @@ class BaseTool(BaseModel, ABC):
) -> Any: ) -> Any:
"""Here goes the actual implementation of the tool.""" """Here goes the actual implementation of the tool."""
def to_langchain(self) -> StructuredTool: def to_structured_tool(self) -> CrewStructuredTool:
"""Convert this tool to a CrewStructuredTool instance."""
self._set_args_schema() self._set_args_schema()
return StructuredTool( return CrewStructuredTool(
name=self.name, name=self.name,
description=self.description, description=self.description,
args_schema=self.args_schema, args_schema=self.args_schema,
@@ -73,10 +75,10 @@ class BaseTool(BaseModel, ABC):
) )
@classmethod @classmethod
def from_langchain(cls, tool: StructuredTool) -> "BaseTool": def from_langchain(cls, tool: CrewStructuredTool) -> "BaseTool":
if cls == Tool: if cls == Tool:
if tool.func is None: if tool.func is None:
raise ValueError("StructuredTool must have a callable 'func'") raise ValueError("CrewStructuredTool must have a callable 'func'")
return Tool( return Tool(
name=tool.name, name=tool.name,
description=tool.description, description=tool.description,
@@ -142,9 +144,9 @@ class Tool(BaseTool):
def to_langchain( def to_langchain(
tools: list[BaseTool | StructuredTool], tools: list[BaseTool | CrewStructuredTool],
) -> list[StructuredTool]: ) -> list[CrewStructuredTool]:
return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools] return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
def tool(*args): def tool(*args):

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.agents.cache import CacheHandler from crewai.agents.cache import CacheHandler
from crewai.tools.structured_tool import CrewStructuredTool
class CacheTools(BaseModel): class CacheTools(BaseModel):
@@ -13,9 +14,7 @@ class CacheTools(BaseModel):
) )
def tool(self): def tool(self):
from langchain.tools import StructuredTool return CrewStructuredTool.from_function(
return StructuredTool.from_function(
func=self.hit_cache, func=self.hit_cache,
name=self.name, name=self.name,
description="Reads directly from the cache", description="Reads directly from the cache",

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
import inspect
import textwrap
from typing import Any, Callable, Optional, Union
from pydantic import BaseModel, Field, create_model
from crewai.utilities.logger import Logger
class CrewStructuredTool:
"""A structured tool that can operate on any number of inputs.
This tool replaces LangChain's StructuredTool with a custom implementation
that integrates better with CrewAI's ecosystem.
"""
def __init__(
self,
name: str,
description: str,
args_schema: type[BaseModel],
func: Callable[..., Any],
) -> None:
"""Initialize the structured tool.
Args:
name: The name of the tool
description: A description of what the tool does
args_schema: The pydantic model for the tool's arguments
func: The function to run when the tool is called
"""
self.name = name
self.description = description
self.args_schema = args_schema
self.func = func
self._logger = Logger()
# Validate the function signature matches the schema
self._validate_function_signature()
@classmethod
def from_function(
cls,
func: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[type[BaseModel]] = None,
infer_schema: bool = True,
**kwargs: Any,
) -> CrewStructuredTool:
"""Create a tool from a function.
Args:
func: The function to create a tool from
name: The name of the tool. Defaults to the function name
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the output directly
args_schema: Optional schema for the function arguments
infer_schema: Whether to infer the schema from the function signature
**kwargs: Additional arguments to pass to the tool
Returns:
A CrewStructuredTool instance
Example:
>>> def add(a: int, b: int) -> int:
... '''Add two numbers'''
... return a + b
>>> tool = CrewStructuredTool.from_function(add)
"""
name = name or func.__name__
description = description or inspect.getdoc(func)
if description is None:
raise ValueError(
f"Function {name} must have a docstring if description not provided."
)
# Clean up the description
description = textwrap.dedent(description).strip()
if args_schema is not None:
# Use provided schema
schema = args_schema
elif infer_schema:
# Infer schema from function signature
schema = cls._create_schema_from_function(name, func)
else:
raise ValueError(
"Either args_schema must be provided or infer_schema must be True."
)
return cls(
name=name,
description=description,
args_schema=schema,
func=func,
)
@staticmethod
def _create_schema_from_function(
name: str,
func: Callable,
) -> type[BaseModel]:
"""Create a Pydantic schema from a function's signature.
Args:
name: The name to use for the schema
func: The function to create a schema from
Returns:
A Pydantic model class
"""
# Get function signature
sig = inspect.signature(func)
# Get type hints
type_hints = inspect.get_type_hints(func)
# Create field definitions
fields = {}
for param_name, param in sig.parameters.items():
# Skip self/cls for methods
if param_name in ("self", "cls"):
continue
# Get type annotation
annotation = type_hints.get(param_name, Any)
# Get default value
default = ... if param.default == param.empty else param.default
# Add field
fields[param_name] = (annotation, Field(default=default))
# Create model
schema_name = f"{name.title()}Schema"
return create_model(schema_name, **fields)
def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema."""
sig = inspect.signature(self.func)
schema_fields = self.args_schema.model_fields
# Check required parameters
for param_name, param in sig.parameters.items():
# Skip self/cls for methods
if param_name in ("self", "cls"):
continue
if param.default == inspect.Parameter.empty:
if param_name not in schema_fields:
raise ValueError(
f"Required function parameter '{param_name}' "
f"not found in args_schema"
)
field = schema_fields[param_name]
if field.default == ... and field.default_factory is None:
# Parameter is required in both function and schema
continue
raise ValueError(
f"Function parameter '{param_name}' is required but has a "
f"default value in the schema"
)
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
"""Parse and validate the input arguments against the schema.
Args:
raw_args: The raw arguments to parse, either as a string or dict
Returns:
The validated arguments as a dictionary
"""
if isinstance(raw_args, str):
try:
import json
raw_args = json.loads(raw_args)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments as JSON: {e}")
try:
validated_args = self.args_schema.model_validate(raw_args)
return validated_args.model_dump()
except Exception as e:
raise ValueError(f"Arguments validation failed: {e}")
async def ainvoke(
self,
input: Union[str, dict],
config: Optional[dict] = None,
**kwargs: Any,
) -> Any:
"""Asynchronously invoke the tool.
Args:
input: The input arguments
config: Optional configuration
**kwargs: Additional keyword arguments
Returns:
The result of the tool execution
"""
parsed_args = self._parse_args(input)
if inspect.iscoroutinefunction(self.func):
return await self.func(**parsed_args, **kwargs)
else:
# Run sync functions in a thread pool
import asyncio
return await asyncio.get_event_loop().run_in_executor(
None, lambda: self.func(**parsed_args, **kwargs)
)
def invoke(
self,
input: Union[str, dict],
config: Optional[dict] = None,
**kwargs: Any,
) -> Any:
"""Synchronously invoke the tool.
Args:
input: The input arguments
config: Optional configuration
**kwargs: Additional keyword arguments
Returns:
The result of the tool execution
"""
parsed_args = self._parse_args(input)
return self.func(**parsed_args, **kwargs)
@property
def args(self) -> dict:
"""Get the tool's input arguments schema."""
return self.args_schema.model_json_schema()["properties"]
def __repr__(self) -> str:
return (
f"CrewStructuredTool(name='{self.name}', description='{self.description}')"
)