mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
feat: add initial changes from langchain
This commit is contained in:
@@ -277,8 +277,8 @@ class Agent(BaseAgent):
|
|||||||
if self.crew and self.crew.knowledge:
|
if self.crew and self.crew.knowledge:
|
||||||
knowledge_snippets = self.crew.knowledge.query([task.prompt()])
|
knowledge_snippets = self.crew.knowledge.query([task.prompt()])
|
||||||
valid_snippets = [
|
valid_snippets = [
|
||||||
result["context"]
|
result["context"]
|
||||||
for result in knowledge_snippets
|
for result in knowledge_snippets
|
||||||
if result and result.get("context")
|
if result and result.get("context")
|
||||||
]
|
]
|
||||||
if valid_snippets:
|
if valid_snippets:
|
||||||
@@ -399,7 +399,7 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if isinstance(tool, CrewAITool):
|
if isinstance(tool, CrewAITool):
|
||||||
tools_list.append(tool.to_langchain())
|
tools_list.append(tool.to_structured_tool())
|
||||||
else:
|
else:
|
||||||
tools_list.append(tool)
|
tools_list.append(tool)
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Type, get_args, get_origin
|
from typing import Any, Callable, Type, get_args, get_origin
|
||||||
|
|
||||||
from langchain_core.tools import StructuredTool
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, validator
|
from pydantic import BaseModel, ConfigDict, Field, validator
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(BaseModel, ABC):
|
class BaseTool(BaseModel, ABC):
|
||||||
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||||
@@ -63,9 +64,10 @@ class BaseTool(BaseModel, ABC):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Here goes the actual implementation of the tool."""
|
"""Here goes the actual implementation of the tool."""
|
||||||
|
|
||||||
def to_langchain(self) -> StructuredTool:
|
def to_structured_tool(self) -> CrewStructuredTool:
|
||||||
|
"""Convert this tool to a CrewStructuredTool instance."""
|
||||||
self._set_args_schema()
|
self._set_args_schema()
|
||||||
return StructuredTool(
|
return CrewStructuredTool(
|
||||||
name=self.name,
|
name=self.name,
|
||||||
description=self.description,
|
description=self.description,
|
||||||
args_schema=self.args_schema,
|
args_schema=self.args_schema,
|
||||||
@@ -73,10 +75,10 @@ class BaseTool(BaseModel, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_langchain(cls, tool: StructuredTool) -> "BaseTool":
|
def from_langchain(cls, tool: CrewStructuredTool) -> "BaseTool":
|
||||||
if cls == Tool:
|
if cls == Tool:
|
||||||
if tool.func is None:
|
if tool.func is None:
|
||||||
raise ValueError("StructuredTool must have a callable 'func'")
|
raise ValueError("CrewStructuredTool must have a callable 'func'")
|
||||||
return Tool(
|
return Tool(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
description=tool.description,
|
description=tool.description,
|
||||||
@@ -142,9 +144,9 @@ class Tool(BaseTool):
|
|||||||
|
|
||||||
|
|
||||||
def to_langchain(
|
def to_langchain(
|
||||||
tools: list[BaseTool | StructuredTool],
|
tools: list[BaseTool | CrewStructuredTool],
|
||||||
) -> list[StructuredTool]:
|
) -> list[CrewStructuredTool]:
|
||||||
return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools]
|
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
|
||||||
|
|
||||||
|
|
||||||
def tool(*args):
|
def tool(*args):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.agents.cache import CacheHandler
|
from crewai.agents.cache import CacheHandler
|
||||||
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
|
|
||||||
|
|
||||||
class CacheTools(BaseModel):
|
class CacheTools(BaseModel):
|
||||||
@@ -13,9 +14,7 @@ class CacheTools(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def tool(self):
|
def tool(self):
|
||||||
from langchain.tools import StructuredTool
|
return CrewStructuredTool.from_function(
|
||||||
|
|
||||||
return StructuredTool.from_function(
|
|
||||||
func=self.hit_cache,
|
func=self.hit_cache,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
description="Reads directly from the cache",
|
description="Reads directly from the cache",
|
||||||
|
|||||||
249
src/crewai/tools/structured_tool.py
Normal file
249
src/crewai/tools/structured_tool.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import textwrap
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, create_model
|
||||||
|
|
||||||
|
from crewai.utilities.logger import Logger
|
||||||
|
|
||||||
|
|
||||||
|
class CrewStructuredTool:
|
||||||
|
"""A structured tool that can operate on any number of inputs.
|
||||||
|
|
||||||
|
This tool replaces LangChain's StructuredTool with a custom implementation
|
||||||
|
that integrates better with CrewAI's ecosystem.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
args_schema: type[BaseModel],
|
||||||
|
func: Callable[..., Any],
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the structured tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name of the tool
|
||||||
|
description: A description of what the tool does
|
||||||
|
args_schema: The pydantic model for the tool's arguments
|
||||||
|
func: The function to run when the tool is called
|
||||||
|
"""
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.args_schema = args_schema
|
||||||
|
self.func = func
|
||||||
|
self._logger = Logger()
|
||||||
|
|
||||||
|
# Validate the function signature matches the schema
|
||||||
|
self._validate_function_signature()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_function(
|
||||||
|
cls,
|
||||||
|
func: Callable,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
return_direct: bool = False,
|
||||||
|
args_schema: Optional[type[BaseModel]] = None,
|
||||||
|
infer_schema: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> CrewStructuredTool:
|
||||||
|
"""Create a tool from a function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to create a tool from
|
||||||
|
name: The name of the tool. Defaults to the function name
|
||||||
|
description: The description of the tool. Defaults to the function docstring
|
||||||
|
return_direct: Whether to return the output directly
|
||||||
|
args_schema: Optional schema for the function arguments
|
||||||
|
infer_schema: Whether to infer the schema from the function signature
|
||||||
|
**kwargs: Additional arguments to pass to the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A CrewStructuredTool instance
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> def add(a: int, b: int) -> int:
|
||||||
|
... '''Add two numbers'''
|
||||||
|
... return a + b
|
||||||
|
>>> tool = CrewStructuredTool.from_function(add)
|
||||||
|
"""
|
||||||
|
name = name or func.__name__
|
||||||
|
description = description or inspect.getdoc(func)
|
||||||
|
|
||||||
|
if description is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Function {name} must have a docstring if description not provided."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up the description
|
||||||
|
description = textwrap.dedent(description).strip()
|
||||||
|
|
||||||
|
if args_schema is not None:
|
||||||
|
# Use provided schema
|
||||||
|
schema = args_schema
|
||||||
|
elif infer_schema:
|
||||||
|
# Infer schema from function signature
|
||||||
|
schema = cls._create_schema_from_function(name, func)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Either args_schema must be provided or infer_schema must be True."
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
args_schema=schema,
|
||||||
|
func=func,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_schema_from_function(
|
||||||
|
name: str,
|
||||||
|
func: Callable,
|
||||||
|
) -> type[BaseModel]:
|
||||||
|
"""Create a Pydantic schema from a function's signature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The name to use for the schema
|
||||||
|
func: The function to create a schema from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Pydantic model class
|
||||||
|
"""
|
||||||
|
# Get function signature
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
|
# Get type hints
|
||||||
|
type_hints = inspect.get_type_hints(func)
|
||||||
|
|
||||||
|
# Create field definitions
|
||||||
|
fields = {}
|
||||||
|
for param_name, param in sig.parameters.items():
|
||||||
|
# Skip self/cls for methods
|
||||||
|
if param_name in ("self", "cls"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get type annotation
|
||||||
|
annotation = type_hints.get(param_name, Any)
|
||||||
|
|
||||||
|
# Get default value
|
||||||
|
default = ... if param.default == param.empty else param.default
|
||||||
|
|
||||||
|
# Add field
|
||||||
|
fields[param_name] = (annotation, Field(default=default))
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
schema_name = f"{name.title()}Schema"
|
||||||
|
return create_model(schema_name, **fields)
|
||||||
|
|
||||||
|
def _validate_function_signature(self) -> None:
|
||||||
|
"""Validate that the function signature matches the args schema."""
|
||||||
|
sig = inspect.signature(self.func)
|
||||||
|
schema_fields = self.args_schema.model_fields
|
||||||
|
|
||||||
|
# Check required parameters
|
||||||
|
for param_name, param in sig.parameters.items():
|
||||||
|
# Skip self/cls for methods
|
||||||
|
if param_name in ("self", "cls"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if param.default == inspect.Parameter.empty:
|
||||||
|
if param_name not in schema_fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"Required function parameter '{param_name}' "
|
||||||
|
f"not found in args_schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
field = schema_fields[param_name]
|
||||||
|
if field.default == ... and field.default_factory is None:
|
||||||
|
# Parameter is required in both function and schema
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Function parameter '{param_name}' is required but has a "
|
||||||
|
f"default value in the schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
|
||||||
|
"""Parse and validate the input arguments against the schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_args: The raw arguments to parse, either as a string or dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The validated arguments as a dictionary
|
||||||
|
"""
|
||||||
|
if isinstance(raw_args, str):
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
|
||||||
|
raw_args = json.loads(raw_args)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(f"Failed to parse arguments as JSON: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
validated_args = self.args_schema.model_validate(raw_args)
|
||||||
|
return validated_args.model_dump()
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Arguments validation failed: {e}")
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self,
|
||||||
|
input: Union[str, dict],
|
||||||
|
config: Optional[dict] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""Asynchronously invoke the tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: The input arguments
|
||||||
|
config: Optional configuration
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the tool execution
|
||||||
|
"""
|
||||||
|
parsed_args = self._parse_args(input)
|
||||||
|
|
||||||
|
if inspect.iscoroutinefunction(self.func):
|
||||||
|
return await self.func(**parsed_args, **kwargs)
|
||||||
|
else:
|
||||||
|
# Run sync functions in a thread pool
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None, lambda: self.func(**parsed_args, **kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self,
|
||||||
|
input: Union[str, dict],
|
||||||
|
config: Optional[dict] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""Synchronously invoke the tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: The input arguments
|
||||||
|
config: Optional configuration
|
||||||
|
**kwargs: Additional keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the tool execution
|
||||||
|
"""
|
||||||
|
parsed_args = self._parse_args(input)
|
||||||
|
return self.func(**parsed_args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args(self) -> dict:
|
||||||
|
"""Get the tool's input arguments schema."""
|
||||||
|
return self.args_schema.model_json_schema()["properties"]
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"CrewStructuredTool(name='{self.name}', description='{self.description}')"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user