Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-26 08:38:15 +00:00)
Squashed 'packages/tools/' changes from 78317b9c..0b3f00e6
0b3f00e6 chore: update project version to 0.73.0 and revise uv.lock dependencies (#455)
ad19b074 feat: replace embedchain with native crewai adapter (#451)

git-subtree-dir: packages/tools
git-subtree-split: 0b3f00e67c0dae24d188c292dc99759fd1c841f7
This commit is contained in:
@@ -1,14 +1,10 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedCodeDocsSearchToolSchema(BaseModel):
|
||||
@@ -42,15 +38,15 @@ class CodeDocsSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, docs_url: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(docs_url, data_type=DataType.DOCS_SITE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
docs_url: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if docs_url is not None:
|
||||
self.add(docs_url)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedCSVSearchToolSchema(BaseModel):
|
||||
@@ -42,15 +38,16 @@ class CSVSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, csv: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(csv, data_type=DataType.CSV)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
csv: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if csv is not None:
|
||||
self.add(csv)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.loaders.directory_loader import DirectoryLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedDirectorySearchToolSchema(BaseModel):
|
||||
@@ -34,8 +29,6 @@ class DirectorySearchTool(RagTool):
|
||||
args_schema: Type[BaseModel] = DirectorySearchToolSchema
|
||||
|
||||
def __init__(self, directory: Optional[str] = None, **kwargs):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
if directory is not None:
|
||||
self.add(directory)
|
||||
@@ -44,16 +37,15 @@ class DirectorySearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, directory: str) -> None:
|
||||
super().add(
|
||||
directory,
|
||||
loader=DirectoryLoader(config=dict(recursive=True)),
|
||||
)
|
||||
super().add(directory, data_type=DataType.DIRECTORY)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
directory: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if directory is not None:
|
||||
self.add(directory)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedDOCXSearchToolSchema(BaseModel):
|
||||
@@ -48,15 +44,15 @@ class DOCXSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, docx: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(docx, data_type=DataType.DOCX)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
docx: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Any:
|
||||
if docx is not None:
|
||||
self.add(docx)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import List, Optional, Type, Any
|
||||
from typing import List, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.loaders.github import GithubLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedGithubSearchToolSchema(BaseModel):
|
||||
@@ -42,7 +37,6 @@ class GithubSearchTool(RagTool):
|
||||
default_factory=lambda: ["code", "repo", "pr", "issue"],
|
||||
description="Content types you want to be included search, options: [code, repo, pr, issue]",
|
||||
)
|
||||
_loader: Any | None = PrivateAttr(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -50,10 +44,7 @@ class GithubSearchTool(RagTool):
|
||||
content_types: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
self._loader = GithubLoader(config={"token": self.gh_token})
|
||||
|
||||
if github_repo and content_types:
|
||||
self.add(repo=github_repo, content_types=content_types)
|
||||
@@ -67,11 +58,10 @@ class GithubSearchTool(RagTool):
|
||||
content_types: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
content_types = content_types or self.content_types
|
||||
|
||||
super().add(
|
||||
f"repo:{repo} type:{','.join(content_types)}",
|
||||
data_type="github",
|
||||
loader=self._loader,
|
||||
f"https://github.com/{repo}",
|
||||
data_type=DataType.GITHUB,
|
||||
metadata={"content_types": content_types, "gh_token": self.gh_token}
|
||||
)
|
||||
|
||||
def _run(
|
||||
@@ -79,10 +69,12 @@ class GithubSearchTool(RagTool):
|
||||
search_query: str,
|
||||
github_repo: Optional[str] = None,
|
||||
content_types: Optional[List[str]] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if github_repo:
|
||||
self.add(
|
||||
repo=github_repo,
|
||||
content_types=content_types,
|
||||
)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -41,7 +41,9 @@ class JSONSearchTool(RagTool):
|
||||
self,
|
||||
search_query: str,
|
||||
json_path: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if json_path is not None:
|
||||
self.add(json_path)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -2,13 +2,9 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedMDXSearchToolSchema(BaseModel):
|
||||
@@ -42,15 +38,15 @@ class MDXSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, mdx: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(mdx, data_type=DataType.MDX)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
mdx: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if mdx is not None:
|
||||
self.add(mdx)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Any, Type
|
||||
|
||||
try:
|
||||
from embedchain.loaders.mysql import MySQLLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class MySQLSearchToolSchema(BaseModel):
|
||||
@@ -27,12 +22,8 @@ class MySQLSearchTool(RagTool):
|
||||
db_uri: str = Field(..., description="Mandatory database URI")
|
||||
|
||||
def __init__(self, table_name: str, **kwargs):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
kwargs["data_type"] = "mysql"
|
||||
kwargs["loader"] = MySQLLoader(config=dict(url=self.db_uri))
|
||||
self.add(table_name)
|
||||
self.add(table_name, data_type=DataType.MYSQL, metadata={"db_uri": self.db_uri})
|
||||
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
|
||||
self._generate_description()
|
||||
|
||||
@@ -46,6 +37,8 @@ class MySQLSearchTool(RagTool):
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -2,13 +2,8 @@ from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedPDFSearchToolSchema(BaseModel):
|
||||
@@ -41,15 +36,15 @@ class PDFSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, pdf: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(pdf, data_type=DataType.PDF_FILE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
pdf: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if pdf is not None:
|
||||
self.add(pdf)
|
||||
return super()._run(query=query)
|
||||
return super()._run(query=query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Any, Type
|
||||
|
||||
try:
|
||||
from embedchain.loaders.postgres import PostgresLoader
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class PGSearchToolSchema(BaseModel):
|
||||
@@ -27,12 +22,8 @@ class PGSearchTool(RagTool):
|
||||
db_uri: str = Field(..., description="Mandatory database URI")
|
||||
|
||||
def __init__(self, table_name: str, **kwargs):
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().__init__(**kwargs)
|
||||
kwargs["data_type"] = "postgres"
|
||||
kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
|
||||
self.add(table_name)
|
||||
self.add(table_name, data_type=DataType.POSTGRES, metadata={"db_uri": self.db_uri})
|
||||
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
|
||||
self._generate_description()
|
||||
|
||||
@@ -46,6 +37,8 @@ class PGSearchTool(RagTool):
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
return super()._run(query=search_query, **kwargs)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit, **kwargs)
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
import portalocker
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing import Any, cast
|
||||
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def query(self, question: str) -> str:
|
||||
def query(
|
||||
self,
|
||||
question: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
"""Query the knowledge base with a question and return the answer."""
|
||||
|
||||
@abstractmethod
|
||||
@@ -25,7 +30,12 @@ class Adapter(BaseModel, ABC):
|
||||
|
||||
class RagTool(BaseTool):
|
||||
class _AdapterPlaceholder(Adapter):
|
||||
def query(self, question: str) -> str:
|
||||
def query(
|
||||
self,
|
||||
question: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, *args: Any, **kwargs: Any) -> None:
|
||||
@@ -34,28 +44,149 @@ class RagTool(BaseTool):
|
||||
name: str = "Knowledge base"
|
||||
description: str = "A knowledge base that can be used to answer questions."
|
||||
summarize: bool = False
|
||||
similarity_threshold: float = 0.6
|
||||
limit: int = 5
|
||||
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
|
||||
config: dict[str, Any] | None = None
|
||||
config: Any | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_default_adapter(self):
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
try:
|
||||
from embedchain import App
|
||||
except ImportError:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
parsed_config = self._parse_config(self.config)
|
||||
|
||||
with portalocker.Lock("crewai-rag-tool.lock", timeout=10):
|
||||
app = App.from_config(config=self.config) if self.config else App()
|
||||
|
||||
self.adapter = EmbedchainAdapter(
|
||||
embedchain_app=app, summarize=self.summarize
|
||||
self.adapter = CrewAIRagAdapter(
|
||||
collection_name="rag_tool_collection",
|
||||
summarize=self.summarize,
|
||||
similarity_threshold=self.similarity_threshold,
|
||||
limit=self.limit,
|
||||
config=parsed_config,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _parse_config(self, config: Any) -> Any:
|
||||
"""Parse complex config format to extract provider-specific config.
|
||||
|
||||
Raises:
|
||||
ValueError: If the config format is invalid or uses unsupported providers.
|
||||
"""
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
if isinstance(config, dict) and "provider" in config:
|
||||
return config
|
||||
|
||||
if isinstance(config, dict):
|
||||
if "vectordb" in config:
|
||||
vectordb_config = config["vectordb"]
|
||||
if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
|
||||
provider = vectordb_config["provider"]
|
||||
provider_config = vectordb_config.get("config", {})
|
||||
|
||||
supported_providers = ["chromadb", "qdrant"]
|
||||
if provider not in supported_providers:
|
||||
raise ValueError(
|
||||
f"Unsupported vector database provider: '{provider}'. "
|
||||
f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
|
||||
)
|
||||
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, provider
|
||||
)
|
||||
|
||||
return self._create_provider_config(
|
||||
provider, provider_config, embedding_function
|
||||
)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, "chromadb"
|
||||
)
|
||||
|
||||
return self._create_provider_config("chromadb", {}, embedding_function)
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
|
||||
"""Create embedding function for the specified vector database provider."""
|
||||
embedding_provider = embedding_config.get("provider")
|
||||
embedding_model_config = embedding_config.get("config", {}).copy()
|
||||
|
||||
if "model" in embedding_model_config:
|
||||
embedding_model_config["model_name"] = embedding_model_config.pop("model")
|
||||
|
||||
factory_config = {"provider": embedding_provider, **embedding_model_config}
|
||||
|
||||
if embedding_provider == "openai" and "api_key" not in factory_config:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
factory_config["api_key"] = api_key
|
||||
|
||||
print(f"Creating embedding function with config: {factory_config}")
|
||||
|
||||
if provider == "chromadb":
|
||||
embedding_func = get_embedding_function(factory_config)
|
||||
print(f"Created embedding function: {embedding_func}")
|
||||
print(f"Embedding function type: {type(embedding_func)}")
|
||||
return embedding_func
|
||||
|
||||
elif provider == "qdrant":
|
||||
chromadb_func = get_embedding_function(factory_config)
|
||||
|
||||
def qdrant_embed_fn(text: str) -> list[float]:
|
||||
"""Embed text using ChromaDB function and convert to list of floats for Qdrant.
|
||||
|
||||
Args:
|
||||
text: The input text to embed.
|
||||
|
||||
Returns:
|
||||
A list of floats representing the embedding.
|
||||
"""
|
||||
embeddings = chromadb_func([text])
|
||||
return embeddings[0] if embeddings and len(embeddings) > 0 else []
|
||||
|
||||
return cast(Any, qdrant_embed_fn)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _create_provider_config(
|
||||
provider: str, provider_config: dict, embedding_function: Any
|
||||
) -> Any:
|
||||
"""Create proper provider config object."""
|
||||
if provider == "chromadb":
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return ChromaDBConfig(**config_kwargs)
|
||||
|
||||
elif provider == "qdrant":
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return QdrantConfig(**config_kwargs)
|
||||
|
||||
return None
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -66,5 +197,13 @@ class RagTool(BaseTool):
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
return f"Relevant Content:\n{self.adapter.query(query)}"
|
||||
threshold = (
|
||||
similarity_threshold
|
||||
if similarity_threshold is not None
|
||||
else self.similarity_threshold
|
||||
)
|
||||
result_limit = limit if limit is not None else self.limit
|
||||
return f"Relevant Content:\n{self.adapter.query(query, similarity_threshold=threshold, limit=result_limit)}"
|
||||
|
||||
@@ -39,7 +39,9 @@ class TXTSearchTool(RagTool):
|
||||
self,
|
||||
search_query: str,
|
||||
txt: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if txt is not None:
|
||||
self.add(txt)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedWebsiteSearchToolSchema(BaseModel):
|
||||
@@ -44,15 +39,15 @@ class WebsiteSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, website: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(website, data_type=DataType.WEB_PAGE)
|
||||
super().add(website, data_type=DataType.WEBSITE)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
website: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if website is not None:
|
||||
self.add(website)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -39,7 +39,9 @@ class XMLSearchTool(RagTool):
|
||||
self,
|
||||
search_query: str,
|
||||
xml: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if xml is not None:
|
||||
self.add(xml)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedYoutubeChannelSearchToolSchema(BaseModel):
|
||||
@@ -55,7 +50,9 @@ class YoutubeChannelSearchTool(RagTool):
|
||||
self,
|
||||
search_query: str,
|
||||
youtube_channel_handle: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if youtube_channel_handle is not None:
|
||||
self.add(youtube_channel_handle)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
try:
|
||||
from embedchain.models.data_type import DataType
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..rag.rag_tool import RagTool
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
class FixedYoutubeVideoSearchToolSchema(BaseModel):
|
||||
@@ -44,15 +40,15 @@ class YoutubeVideoSearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, youtube_video_url: str) -> None:
|
||||
if not EMBEDCHAIN_AVAILABLE:
|
||||
raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
|
||||
super().add(youtube_video_url, data_type=DataType.YOUTUBE_VIDEO)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
search_query: str,
|
||||
youtube_video_url: Optional[str] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
if youtube_video_url is not None:
|
||||
self.add(youtube_video_url)
|
||||
return super()._run(query=search_query)
|
||||
return super()._run(query=search_query, similarity_threshold=similarity_threshold, limit=limit)
|
||||
|
||||
Reference in New Issue
Block a user