mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-21 05:48:14 +00:00
79 lines
2.6 KiB
Python
79 lines
2.6 KiB
Python
from typing import Any, Optional, Type, cast
|
|
|
|
from crewai.tools import BaseTool
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class LlamaIndexTool(BaseTool):
|
|
"""Tool to wrap LlamaIndex tools/query engines."""
|
|
|
|
llama_index_tool: Any
|
|
|
|
def _run(
|
|
self,
|
|
*args: Any,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
"""Run tool."""
|
|
from llama_index.core.tools import BaseTool as LlamaBaseTool
|
|
|
|
tool = cast(LlamaBaseTool, self.llama_index_tool)
|
|
return tool(*args, **kwargs)
|
|
|
|
@classmethod
|
|
def from_tool(cls, tool: Any, **kwargs: Any) -> "LlamaIndexTool":
|
|
from llama_index.core.tools import BaseTool as LlamaBaseTool
|
|
|
|
if not isinstance(tool, LlamaBaseTool):
|
|
raise ValueError(f"Expected a LlamaBaseTool, got {type(tool)}")
|
|
tool = cast(LlamaBaseTool, tool)
|
|
|
|
if tool.metadata.fn_schema is None:
|
|
raise ValueError(
|
|
"The LlamaIndex tool does not have an fn_schema specified."
|
|
)
|
|
args_schema = cast(Type[BaseModel], tool.metadata.fn_schema)
|
|
|
|
return cls(
|
|
name=tool.metadata.name,
|
|
description=tool.metadata.description,
|
|
args_schema=args_schema,
|
|
llama_index_tool=tool,
|
|
**kwargs,
|
|
)
|
|
|
|
@classmethod
|
|
def from_query_engine(
|
|
cls,
|
|
query_engine: Any,
|
|
name: Optional[str] = None,
|
|
description: Optional[str] = None,
|
|
return_direct: bool = False,
|
|
**kwargs: Any,
|
|
) -> "LlamaIndexTool":
|
|
from llama_index.core.query_engine import BaseQueryEngine
|
|
from llama_index.core.tools import QueryEngineTool
|
|
|
|
if not isinstance(query_engine, BaseQueryEngine):
|
|
raise ValueError(f"Expected a BaseQueryEngine, got {type(query_engine)}")
|
|
|
|
# NOTE: by default the schema expects an `input` variable. However this
|
|
# confuses crewAI so we are renaming to `query`.
|
|
class QueryToolSchema(BaseModel):
|
|
"""Schema for query tool."""
|
|
|
|
query: str = Field(..., description="Search query for the query tool.")
|
|
|
|
# NOTE: setting `resolve_input_errors` to True is important because the schema expects `input` but we are using `query`
|
|
query_engine_tool = QueryEngineTool.from_defaults(
|
|
query_engine,
|
|
name=name,
|
|
description=description,
|
|
return_direct=return_direct,
|
|
resolve_input_errors=True,
|
|
)
|
|
# HACK: we are replacing the schema with our custom schema
|
|
query_engine_tool.metadata.fn_schema = QueryToolSchema
|
|
|
|
return cls.from_tool(query_engine_tool, **kwargs)
|