mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 09:38:17 +00:00
Squashed 'packages/tools/' changes from 78317b9c..0b3f00e6:
- 0b3f00e6 chore: update project version to 0.73.0 and revise uv.lock dependencies (#455)
- ad19b074 feat: replace embedchain with native crewai adapter (#451)
git-subtree-dir: packages/tools
git-subtree-split: 0b3f00e67c0dae24d188c292dc99759fd1c841f7
This commit is contained in:
215
crewai_tools/adapters/crewai_rag_adapter.py
Normal file
215
crewai_tools/adapters/crewai_rag_adapter.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Adapter for CrewAI's native RAG system."""
|
||||
|
||||
from typing import Any, TypedDict, TypeAlias
|
||||
from typing_extensions import Unpack
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
|
||||
from pydantic import Field, PrivateAttr
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
    """Parameters for adding documents to the RAG system.

    All keys are optional (``total=False``); ``add()`` reads only
    ``data_type`` and ``metadata`` directly, the remaining keys exist for
    interface compatibility with the tools that call this adapter.
    """

    # Explicit content type; when omitted, the type is inferred from the content.
    data_type: DataType
    # Base metadata merged into every stored chunk.
    metadata: dict[str, Any]
    # Source-specific references accepted by the various RAG tools.
    website: str
    url: str
    file_path: str | Path
    github_url: str
    youtube_url: str
    directory_path: str | Path
|
||||
|
||||
class CrewAIRagAdapter(Adapter):
    """Adapter that uses CrewAI's native RAG system.

    Supports custom vector database configuration through the config parameter.

    Attributes:
        collection_name: Vector store collection documents are written to and
            queried from.
        summarize: Kept for interface parity with other adapters; not used by
            this adapter's query path.
        similarity_threshold: Default minimum similarity score for results.
        limit: Default maximum number of results returned by ``query``.
        config: Optional custom vector database configuration; when ``None``
            the globally configured RAG client is used.
    """

    collection_name: str = "default"
    summarize: bool = False
    similarity_threshold: float = 0.6
    limit: int = 5
    config: RagConfigType | None = None
    # Set in model_post_init; never None afterwards.
    _client: BaseClient | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:
        """Initialize the CrewAI RAG client after model initialization."""
        if self.config is not None:
            self._client = create_client(self.config)
        else:
            self._client = get_rag_client()
        # Ensure the target collection exists before any add/query call.
        self._client.get_or_create_collection(collection_name=self.collection_name)

    def query(self, question: str, similarity_threshold: float | None = None, limit: int | None = None) -> str:
        """Query the knowledge base with a question.

        Args:
            question: The question to ask
            similarity_threshold: Minimum similarity score for results (default: 0.6)
            limit: Maximum number of results to return (default: 5)

        Returns:
            Relevant content from the knowledge base
        """
        search_limit = limit if limit is not None else self.limit
        search_threshold = similarity_threshold if similarity_threshold is not None else self.similarity_threshold

        results: list[SearchResult] = self._client.search(
            collection_name=self.collection_name,
            query=question,
            limit=search_limit,
            score_threshold=search_threshold,
        )

        if not results:
            return "No relevant content found."

        # Join the non-empty result contents, separated by blank lines.
        contents: list[str] = [content for result in results if (content := result.get("content", ""))]
        return "\n\n".join(contents)

    def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
        """Add content to the knowledge base.

        This method handles various input types and converts them to documents
        for the vector database. It supports the data_type parameter for
        compatibility with existing tools.

        Args:
            *args: Content items to add (strings, paths, or document dicts)
            **kwargs: Additional parameters including data_type, metadata, etc.
        """
        from crewai_tools.rag.data_types import DataTypes

        documents: list[BaseRecord] = []
        # An explicit data_type from kwargs applies to every argument.
        # BUG FIX: the previous code reused one mutable `data_type` variable
        # across loop iterations, so the type inferred for the first argument
        # leaked into all subsequent arguments. Inference is now per-argument.
        explicit_data_type: DataType | None = kwargs.get("data_type")
        base_metadata: dict[str, Any] = kwargs.get("metadata", {})

        for arg in args:
            # Normalize the argument into a string reference we can inspect.
            source_ref: str
            if isinstance(arg, dict):
                source_ref = str(arg.get("source", arg.get("content", "")))
            else:
                source_ref = str(arg)

            data_type: DataType = (
                explicit_data_type if explicit_data_type else DataTypes.from_content(source_ref)
            )

            if data_type == DataType.DIRECTORY:
                self._ingest_directory(source_ref, arg, base_metadata, documents)
            else:
                self._ingest_single(source_ref, arg, data_type, base_metadata, documents)

        if documents:
            self._client.add_documents(
                collection_name=self.collection_name,
                documents=documents,
            )

    def _ingest_directory(
        self,
        directory: str,
        arg: ContentItem,
        base_metadata: dict[str, Any],
        documents: list[BaseRecord],
    ) -> None:
        """Walk *directory* and append one record per chunk of each readable file.

        Hidden files/directories, known binary extensions and ``__pycache__``
        contents are skipped; files that fail to process are silently ignored
        (best-effort ingestion).

        Raises:
            ValueError: If *directory* does not exist.
        """
        import os

        from crewai_tools.rag.base_loader import LoaderResult
        from crewai_tools.rag.data_types import DataTypes
        from crewai_tools.rag.source_content import SourceContent

        if not os.path.isdir(directory):
            raise ValueError(f"Directory does not exist: {directory}")

        # Binary and non-text file extensions to skip.
        binary_extensions = {'.pyc', '.pyo', '.png', '.jpg', '.jpeg', '.gif',
                             '.bmp', '.ico', '.svg', '.webp', '.pdf', '.zip',
                             '.tar', '.gz', '.bz2', '.7z', '.rar', '.exe',
                             '.dll', '.so', '.dylib', '.bin', '.dat', '.db',
                             '.sqlite', '.class', '.jar', '.war', '.ear'}

        for root, dirs, files in os.walk(directory):
            # Prune hidden directories in place so os.walk never descends into them.
            dirs[:] = [d for d in dirs if not d.startswith('.')]

            # Hoisted out of the per-file loop: skip anything inside __pycache__.
            if '__pycache__' in root:
                continue

            for filename in files:
                if filename.startswith('.'):
                    continue

                # Skip binary files based on extension.
                if os.path.splitext(filename)[1].lower() in binary_extensions:
                    continue

                file_path: str = os.path.join(root, filename)
                try:
                    file_data_type: DataType = DataTypes.from_content(file_path)
                    file_loader = file_data_type.get_loader()
                    file_chunker = file_data_type.get_chunker()

                    file_result: LoaderResult = file_loader.load(SourceContent(file_path))
                    file_chunks = file_chunker.chunk(file_result.content)

                    for chunk_idx, file_chunk in enumerate(file_chunks):
                        file_metadata: dict[str, Any] = base_metadata.copy()
                        file_metadata.update(file_result.metadata)
                        file_metadata["data_type"] = str(file_data_type)
                        file_metadata["file_path"] = file_path
                        file_metadata["chunk_index"] = chunk_idx
                        file_metadata["total_chunks"] = len(file_chunks)

                        # Per-argument metadata overrides the base metadata.
                        if isinstance(arg, dict):
                            file_metadata.update(arg.get("metadata", {}))

                        # Deterministic id: same doc/chunk content yields the same id.
                        chunk_id = hashlib.sha256(
                            f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
                        ).hexdigest()

                        documents.append({
                            "doc_id": chunk_id,
                            "content": file_chunk,
                            "metadata": sanitize_metadata_for_chromadb(file_metadata),
                        })
                except Exception:
                    # Best-effort: silently skip files that can't be processed.
                    continue

    def _ingest_single(
        self,
        source_ref: str,
        arg: ContentItem,
        data_type: DataType,
        base_metadata: dict[str, Any],
        documents: list[BaseRecord],
    ) -> None:
        """Load and chunk a single non-directory source, appending records.

        Raises:
            FileNotFoundError: If *data_type* is a file-backed type and the
                file does not exist.
        """
        import os

        from crewai_tools.rag.base_loader import LoaderResult
        from crewai_tools.rag.source_content import SourceContent

        metadata: dict[str, Any] = base_metadata.copy()

        # File-backed types must exist on disk before loading.
        if data_type in [DataType.PDF_FILE, DataType.TEXT_FILE, DataType.DOCX,
                         DataType.CSV, DataType.JSON, DataType.XML, DataType.MDX]:
            if not os.path.isfile(source_ref):
                raise FileNotFoundError(f"File does not exist: {source_ref}")

        loader = data_type.get_loader()
        chunker = data_type.get_chunker()

        loader_result: LoaderResult = loader.load(SourceContent(source_ref))
        chunks = chunker.chunk(loader_result.content)

        for i, chunk in enumerate(chunks):
            chunk_metadata: dict[str, Any] = metadata.copy()
            chunk_metadata.update(loader_result.metadata)
            chunk_metadata["data_type"] = str(data_type)
            chunk_metadata["chunk_index"] = i
            chunk_metadata["total_chunks"] = len(chunks)
            chunk_metadata["source"] = source_ref

            # Per-argument metadata overrides the base metadata.
            if isinstance(arg, dict):
                chunk_metadata.update(arg.get("metadata", {}))

            # Deterministic id: same doc/chunk content yields the same id.
            chunk_id = hashlib.sha256(f"{loader_result.doc_id}_{i}_{chunk}".encode()).hexdigest()

            documents.append({
                "doc_id": chunk_id,
                "content": chunk,
                "metadata": sanitize_metadata_for_chromadb(chunk_metadata),
            })
|
||||
@@ -1,34 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
try:
|
||||
from embedchain import App
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
|
||||
class EmbedchainAdapter(Adapter):
    """RAG adapter backed by an embedchain ``App`` instance."""

    embedchain_app: Any  # Will be App when embedchain is available
    summarize: bool = False

    def __init__(self, **data):
        """Validate that embedchain is importable, then delegate to the model init."""
        if not EMBEDCHAIN_AVAILABLE:
            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
        super().__init__(**data)

    def query(self, question: str) -> str:
        """Answer *question*; return the summary or the raw cited sources."""
        # dry_run skips the LLM call when we only want the raw sources back.
        answer, citations = self.embedchain_app.query(
            question, citations=True, dry_run=(not self.summarize)
        )
        if not self.summarize:
            snippets = [citation[0] for citation in citations]
            return "\n\n".join(snippets)
        return answer

    def add(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments straight to the embedchain app."""
        self.embedchain_app.add(*args, **kwargs)
|
||||
@@ -1,41 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
try:
|
||||
from embedchain import App
|
||||
EMBEDCHAIN_AVAILABLE = True
|
||||
except ImportError:
|
||||
EMBEDCHAIN_AVAILABLE = False
|
||||
|
||||
|
||||
class PDFEmbedchainAdapter(Adapter):
    """Embedchain-backed adapter that scopes queries to the last-added PDF source."""

    embedchain_app: Any  # Will be App when embedchain is available
    summarize: bool = False
    src: Optional[str] = None

    def __init__(self, **data):
        """Validate that embedchain is importable, then delegate to the model init."""
        if not EMBEDCHAIN_AVAILABLE:
            raise ImportError("embedchain is not installed. Please install it with `pip install crewai-tools[embedchain]`")
        super().__init__(**data)

    def query(self, question: str) -> str:
        """Answer *question*, filtered to the tracked source when one is set."""
        # Restrict retrieval to the most recently added source, if any.
        filter_clause = None
        if self.src:
            filter_clause = {"app_id": self.embedchain_app.config.id, "source": self.src}

        # dry_run skips the LLM call when we only want the raw sources back.
        answer, citations = self.embedchain_app.query(
            question, citations=True, dry_run=(not self.summarize), where=filter_clause
        )
        if not self.summarize:
            snippets = [citation[0] for citation in citations]
            return "\n\n".join(snippets)
        return answer

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Record the first positional arg as the active source, then delegate."""
        self.src = next(iter(args), None)
        self.embedchain_app.add(*args, **kwargs)
|
||||
Reference in New Issue
Block a user