From 1edff675ebbcc8f80647aacadab7d7f4afcd16b7 Mon Sep 17 00:00:00 2001 From: Eduardo Chiarotti Date: Fri, 22 Nov 2024 15:49:40 -0300 Subject: [PATCH] feat: add initial changes from langchain --- src/crewai/agent.py | 6 +- src/crewai/tools/base_tool.py | 18 +- src/crewai/tools/cache_tools/cache_tools.py | 5 +- src/crewai/tools/structured_tool.py | 249 ++++++++++++++++++++ 4 files changed, 264 insertions(+), 14 deletions(-) create mode 100644 src/crewai/tools/structured_tool.py diff --git a/src/crewai/agent.py b/src/crewai/agent.py index d17cbbdfe..04ae6eaf8 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -277,8 +277,8 @@ class Agent(BaseAgent): if self.crew and self.crew.knowledge: knowledge_snippets = self.crew.knowledge.query([task.prompt()]) valid_snippets = [ - result["context"] - for result in knowledge_snippets + result["context"] + for result in knowledge_snippets if result and result.get("context") ] if valid_snippets: @@ -399,7 +399,7 @@ class Agent(BaseAgent): for tool in tools: if isinstance(tool, CrewAITool): - tools_list.append(tool.to_langchain()) + tools_list.append(tool.to_structured_tool()) else: tools_list.append(tool) except ModuleNotFoundError: diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py index f41fb7c0b..06e427528 100644 --- a/src/crewai/tools/base_tool.py +++ b/src/crewai/tools/base_tool.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Type, get_args, get_origin -from langchain_core.tools import StructuredTool from pydantic import BaseModel, ConfigDict, Field, validator from pydantic import BaseModel as PydanticBaseModel +from crewai.tools.structured_tool import CrewStructuredTool + class BaseTool(BaseModel, ABC): class _ArgsSchemaPlaceholder(PydanticBaseModel): @@ -63,9 +64,10 @@ class BaseTool(BaseModel, ABC): ) -> Any: """Here goes the actual implementation of the tool.""" - def to_langchain(self) -> StructuredTool: + def to_structured_tool(self) -> 
CrewStructuredTool: + """Convert this tool to a CrewStructuredTool instance.""" self._set_args_schema() - return StructuredTool( + return CrewStructuredTool( name=self.name, description=self.description, args_schema=self.args_schema, @@ -73,10 +75,10 @@ class BaseTool(BaseModel, ABC): ) @classmethod - def from_langchain(cls, tool: StructuredTool) -> "BaseTool": + def from_langchain(cls, tool: CrewStructuredTool) -> "BaseTool": if cls == Tool: if tool.func is None: - raise ValueError("StructuredTool must have a callable 'func'") + raise ValueError("CrewStructuredTool must have a callable 'func'") return Tool( name=tool.name, description=tool.description, @@ -142,9 +144,9 @@ class Tool(BaseTool): def to_langchain( - tools: list[BaseTool | StructuredTool], -) -> list[StructuredTool]: - return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools] + tools: list[BaseTool | CrewStructuredTool], +) -> list[CrewStructuredTool]: + return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools] def tool(*args): diff --git a/src/crewai/tools/cache_tools/cache_tools.py b/src/crewai/tools/cache_tools/cache_tools.py index a0bb2dbad..a81ce98cf 100644 --- a/src/crewai/tools/cache_tools/cache_tools.py +++ b/src/crewai/tools/cache_tools/cache_tools.py @@ -1,6 +1,7 @@ from pydantic import BaseModel, Field from crewai.agents.cache import CacheHandler +from crewai.tools.structured_tool import CrewStructuredTool class CacheTools(BaseModel): @@ -13,9 +14,7 @@ class CacheTools(BaseModel): ) def tool(self): - from langchain.tools import StructuredTool - - return StructuredTool.from_function( + return CrewStructuredTool.from_function( func=self.hit_cache, name=self.name, description="Reads directly from the cache", diff --git a/src/crewai/tools/structured_tool.py b/src/crewai/tools/structured_tool.py new file mode 100644 index 000000000..dd5d5edb2 --- /dev/null +++ b/src/crewai/tools/structured_tool.py @@ -0,0 +1,249 @@ +from __future__ import annotations + 
import asyncio
import inspect
import json
import textwrap
from typing import Any, Callable, Optional, Union, get_type_hints

from pydantic import BaseModel, Field, create_model

from crewai.utilities.logger import Logger

# Parameter kinds (*args / **kwargs) that cannot be expressed as named schema
# fields and therefore must be ignored when inferring or validating a schema.
_VARIADIC_KINDS = (
    inspect.Parameter.VAR_POSITIONAL,
    inspect.Parameter.VAR_KEYWORD,
)


class CrewStructuredTool:
    """A structured tool that can operate on any number of inputs.

    This tool replaces LangChain's StructuredTool with a custom implementation
    that integrates better with CrewAI's ecosystem.

    Attributes:
        name: The name of the tool.
        description: A description of what the tool does.
        args_schema: The pydantic model describing the tool's arguments.
        func: The callable executed by ``invoke`` / ``ainvoke``.
    """

    def __init__(
        self,
        name: str,
        description: str,
        args_schema: type[BaseModel],
        func: Callable[..., Any],
    ) -> None:
        """Initialize the structured tool.

        Args:
            name: The name of the tool
            description: A description of what the tool does
            args_schema: The pydantic model for the tool's arguments
            func: The function to run when the tool is called

        Raises:
            ValueError: If a required parameter of ``func`` is missing from
                ``args_schema`` or is not required there.
        """
        self.name = name
        self.description = description
        self.args_schema = args_schema
        self.func = func
        self._logger = Logger()

        # Validate the function signature matches the schema
        self._validate_function_signature()

    @classmethod
    def from_function(
        cls,
        func: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        return_direct: bool = False,
        args_schema: Optional[type[BaseModel]] = None,
        infer_schema: bool = True,
        **kwargs: Any,
    ) -> "CrewStructuredTool":
        """Create a tool from a function.

        Args:
            func: The function to create a tool from
            name: The name of the tool. Defaults to the function name
            description: The description of the tool. Defaults to the function
                docstring
            return_direct: Accepted for signature compatibility; not used by
                this implementation
            args_schema: Optional schema for the function arguments
            infer_schema: Whether to infer the schema from the function
                signature when ``args_schema`` is not given
            **kwargs: Accepted for signature compatibility; not used by this
                implementation

        Returns:
            A CrewStructuredTool instance

        Raises:
            ValueError: If no description can be determined, or if neither
                ``args_schema`` nor ``infer_schema`` yields a schema.

        Example:
            >>> def add(a: int, b: int) -> int:
            ...     '''Add two numbers'''
            ...     return a + b
            >>> tool = CrewStructuredTool.from_function(add)
        """
        name = name or func.__name__
        description = description or inspect.getdoc(func)

        if description is None:
            raise ValueError(
                f"Function {name} must have a docstring if description not provided."
            )

        # Clean up the description
        description = textwrap.dedent(description).strip()

        if args_schema is not None:
            # Use provided schema
            schema = args_schema
        elif infer_schema:
            # Infer schema from function signature
            schema = cls._create_schema_from_function(name, func)
        else:
            raise ValueError(
                "Either args_schema must be provided or infer_schema must be True."
            )

        return cls(
            name=name,
            description=description,
            args_schema=schema,
            func=func,
        )

    @staticmethod
    def _create_schema_from_function(
        name: str,
        func: Callable,
    ) -> type[BaseModel]:
        """Create a Pydantic schema from a function's signature.

        ``self``/``cls`` and variadic (``*args`` / ``**kwargs``) parameters are
        skipped because they cannot be modelled as named fields. Parameters
        without a type hint are typed as ``Any``; parameters without a default
        become required fields.

        Args:
            name: The name to use for the schema
            func: The function to create a schema from

        Returns:
            A Pydantic model class
        """
        sig = inspect.signature(func)

        # Resolved type hints live in ``typing``; ``inspect`` has no
        # ``get_type_hints`` function.
        type_hints = get_type_hints(func)

        fields = {}
        for param_name, param in sig.parameters.items():
            # Skip self/cls for methods
            if param_name in ("self", "cls"):
                continue

            # *args / **kwargs have no fixed name to validate against.
            if param.kind in _VARIADIC_KINDS:
                continue

            annotation = type_hints.get(param_name, Any)

            # ``...`` marks the field as required in pydantic. The sentinel is
            # compared by identity so defaults with custom ``__eq__`` are safe.
            default = ... if param.default is inspect.Parameter.empty else param.default

            fields[param_name] = (annotation, Field(default=default))

        schema_name = f"{name.title()}Schema"
        return create_model(schema_name, **fields)

    def _validate_function_signature(self) -> None:
        """Validate that the function signature matches the args schema.

        Every named parameter of ``self.func`` that has no default must exist
        in ``self.args_schema`` and be required there as well.

        Raises:
            ValueError: If a required function parameter is missing from the
                schema, or is present but carries a default/default factory.
        """
        sig = inspect.signature(self.func)
        schema_fields = self.args_schema.model_fields

        for param_name, param in sig.parameters.items():
            # Skip self/cls for methods
            if param_name in ("self", "cls"):
                continue

            # *args / **kwargs never appear as schema fields.
            if param.kind in _VARIADIC_KINDS:
                continue

            # Parameters with a default are optional; nothing to enforce.
            if param.default is not inspect.Parameter.empty:
                continue

            if param_name not in schema_fields:
                raise ValueError(
                    f"Required function parameter '{param_name}' "
                    f"not found in args_schema"
                )

            # Pydantic v2 stores ``PydanticUndefined`` (not ``...``) as the
            # default of a required field, so ask the field itself.
            if not schema_fields[param_name].is_required():
                raise ValueError(
                    f"Function parameter '{param_name}' is required but has a "
                    f"default value in the schema"
                )

    def _parse_args(self, raw_args: Union[str, dict]) -> dict:
        """Parse and validate the input arguments against the schema.

        Args:
            raw_args: The raw arguments to parse, either as a JSON string or
                a dict

        Returns:
            The validated arguments as a dictionary

        Raises:
            ValueError: If the string is not valid JSON or the arguments do
                not satisfy ``args_schema``. The original error is chained as
                ``__cause__``.
        """
        if isinstance(raw_args, str):
            try:
                raw_args = json.loads(raw_args)
            except json.JSONDecodeError as e:
                raise ValueError(f"Failed to parse arguments as JSON: {e}") from e

        try:
            validated_args = self.args_schema.model_validate(raw_args)
            return validated_args.model_dump()
        except Exception as e:
            # Kept broad on purpose: callers rely on a single ValueError for
            # any validation failure.
            raise ValueError(f"Arguments validation failed: {e}") from e

    async def ainvoke(
        self,
        input: Union[str, dict],
        config: Optional[dict] = None,
        **kwargs: Any,
    ) -> Any:
        """Asynchronously invoke the tool.

        Coroutine functions are awaited directly; plain functions are run in
        the event loop's default executor so they do not block the loop.

        Args:
            input: The input arguments
            config: Optional configuration (currently unused)
            **kwargs: Additional keyword arguments forwarded to the function

        Returns:
            The result of the tool execution

        Raises:
            ValueError: If ``input`` cannot be parsed or validated.
        """
        parsed_args = self._parse_args(input)

        if inspect.iscoroutinefunction(self.func):
            return await self.func(**parsed_args, **kwargs)

        # We are inside a coroutine, so a loop is guaranteed to be running;
        # ``get_event_loop`` is deprecated for this use.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: self.func(**parsed_args, **kwargs)
        )

    def invoke(
        self,
        input: Union[str, dict],
        config: Optional[dict] = None,
        **kwargs: Any,
    ) -> Any:
        """Synchronously invoke the tool.

        Args:
            input: The input arguments
            config: Optional configuration (currently unused)
            **kwargs: Additional keyword arguments forwarded to the function

        Returns:
            The result of the tool execution

        Raises:
            ValueError: If ``input`` cannot be parsed or validated.
        """
        parsed_args = self._parse_args(input)
        return self.func(**parsed_args, **kwargs)

    @property
    def args(self) -> dict:
        """Get the tool's input arguments schema."""
        # ``.get`` guards against a schema that emits no "properties" key.
        return self.args_schema.model_json_schema().get("properties", {})

    def __repr__(self) -> str:
        return (
            f"CrewStructuredTool(name='{self.name}', description='{self.description}')"
        )