"""Adapter for CrewAI's native RAG system."""
|
|
|
|
from typing import Any, TypedDict, TypeAlias
|
|
from typing_extensions import Unpack
|
|
from pathlib import Path
|
|
import hashlib
|
|
|
|
from pydantic import Field, PrivateAttr
|
|
from crewai.rag.config.utils import get_rag_client
|
|
from crewai.rag.config.types import RagConfigType
|
|
from crewai.rag.types import BaseRecord, SearchResult
|
|
from crewai.rag.core.base_client import BaseClient
|
|
from crewai.rag.factory import create_client
|
|
|
|
from crewai_tools.tools.rag.rag_tool import Adapter
|
|
from crewai_tools.rag.data_types import DataType
|
|
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
|
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
|
|
|
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
|
|
|
class AddDocumentParams(TypedDict, total=False):
|
|
"""Parameters for adding documents to the RAG system."""
|
|
data_type: DataType
|
|
metadata: dict[str, Any]
|
|
website: str
|
|
url: str
|
|
file_path: str | Path
|
|
github_url: str
|
|
youtube_url: str
|
|
directory_path: str | Path
|
|
|
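
# The TypedDict above constrains keyword names only; a typical call shape
# looks like this sketch (hypothetical path, for illustration):
#
#     adapter.add("./docs/report.pdf", data_type=DataType.PDF_FILE,
#                 metadata={"team": "research"})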


class CrewAIRagAdapter(Adapter):
    """Adapter that uses CrewAI's native RAG system.

    Supports custom vector database configuration through the config
    parameter; when no config is supplied, the process-wide RAG client is
    reused.
    """

    collection_name: str = "default"
    summarize: bool = False
    similarity_threshold: float = 0.6
    limit: int = 5
    config: RagConfigType | None = None
    _client: BaseClient | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:
        """Initialize the CrewAI RAG client after model initialization."""
        # A dedicated client is created only when a custom config is given;
        # otherwise the globally configured client is shared.
        if self.config is not None:
            self._client = create_client(self.config)
        else:
            self._client = get_rag_client()
        self._client.get_or_create_collection(collection_name=self.collection_name)
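
    # A custom vector store can be supplied at construction time, e.g.
    # (my_config is a hypothetical RagConfigType value):
    #
    #     adapter = CrewAIRagAdapter(collection_name="docs", config=my_config)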

    def query(
        self,
        question: str,
        similarity_threshold: float | None = None,
        limit: int | None = None,
    ) -> str:
        """Query the knowledge base with a question.

        Args:
            question: The question to ask.
            similarity_threshold: Minimum similarity score for results
                (falls back to self.similarity_threshold, default 0.6).
            limit: Maximum number of results to return (falls back to
                self.limit, default 5).

        Returns:
            Relevant content from the knowledge base, or a fallback message
            when no result clears the threshold.
        """
        search_limit = limit if limit is not None else self.limit
        search_threshold = (
            similarity_threshold
            if similarity_threshold is not None
            else self.similarity_threshold
        )

        results: list[SearchResult] = self._client.search(
            collection_name=self.collection_name,
            query=question,
            limit=search_limit,
            score_threshold=search_threshold,
        )

        if not results:
            return "No relevant content found."

        # Join the text of every hit, separated by blank lines.
        contents: list[str] = []
        for result in results:
            content: str = result.get("content", "")
            if content:
                contents.append(content)

        return "\n\n".join(contents)
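
    # Per-call overrides take precedence over the instance defaults, e.g.
    # (illustrative):
    #
    #     answer = adapter.query("How is auth configured?",
    #                            similarity_threshold=0.4, limit=3)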

    def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
        """Add content to the knowledge base.

        This method handles various input types and converts them to
        documents for the vector database. It supports the data_type
        parameter for compatibility with existing tools.

        Args:
            *args: Content items to add (strings, paths, or document dicts).
            **kwargs: Additional parameters including data_type, metadata, etc.
        """
        # Deferred imports used only by this method.
        import os

        from crewai_tools.rag.base_loader import LoaderResult
        from crewai_tools.rag.data_types import DataTypes
        from crewai_tools.rag.source_content import SourceContent

        documents: list[BaseRecord] = []
        data_type: DataType | None = kwargs.get("data_type")
        base_metadata: dict[str, Any] = kwargs.get("metadata", {})
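
        # Each positional argument is resolved to a source reference, then
        # routed to the directory walker or the single-source loader below.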

        for arg in args:
            source_ref: str
            if isinstance(arg, dict):
                source_ref = str(arg.get("source", arg.get("content", "")))
            else:
                source_ref = str(arg)

            # Resolve the type per item: an explicit data_type kwarg applies
            # to every item, while an inferred type must not leak from one
            # argument into the next.
            item_data_type: DataType = data_type or DataTypes.from_content(source_ref)

            if item_data_type == DataType.DIRECTORY:
                if not os.path.isdir(source_ref):
                    raise ValueError(f"Directory does not exist: {source_ref}")
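
                # Directories are expanded inline: every retained file goes
                # through the same load -> chunk -> metadata -> hash pipeline
                # used for single sources below.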

                # Extensions for binary or otherwise non-text files to skip.
                binary_extensions = {
                    ".pyc", ".pyo", ".png", ".jpg", ".jpeg", ".gif",
                    ".bmp", ".ico", ".svg", ".webp", ".pdf", ".zip",
                    ".tar", ".gz", ".bz2", ".7z", ".rar", ".exe",
                    ".dll", ".so", ".dylib", ".bin", ".dat", ".db",
                    ".sqlite", ".class", ".jar", ".war", ".ear",
                }

                for root, dirs, files in os.walk(source_ref):
                    # Prune hidden and __pycache__ directories in place so
                    # os.walk never descends into them.
                    dirs[:] = [
                        d for d in dirs if not d.startswith(".") and d != "__pycache__"
                    ]

                    for filename in files:
                        if filename.startswith("."):
                            continue

                        # Skip binary files based on extension.
                        file_ext = os.path.splitext(filename)[1].lower()
                        if file_ext in binary_extensions:
                            continue

                        file_path: str = os.path.join(root, filename)
                        try:
                            file_data_type: DataType = DataTypes.from_content(file_path)
                            file_loader = file_data_type.get_loader()
                            file_chunker = file_data_type.get_chunker()

                            file_source = SourceContent(file_path)
                            file_result: LoaderResult = file_loader.load(file_source)

                            file_chunks = file_chunker.chunk(file_result.content)

                            for chunk_idx, file_chunk in enumerate(file_chunks):
                                file_metadata: dict[str, Any] = base_metadata.copy()
                                file_metadata.update(file_result.metadata)
                                file_metadata["data_type"] = str(file_data_type)
                                file_metadata["file_path"] = file_path
                                file_metadata["chunk_index"] = chunk_idx
                                file_metadata["total_chunks"] = len(file_chunks)

                                if isinstance(arg, dict):
                                    file_metadata.update(arg.get("metadata", {}))

                                # Content-derived id keeps repeated adds idempotent.
                                chunk_id = hashlib.sha256(
                                    f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
                                ).hexdigest()

                                documents.append({
                                    "doc_id": chunk_id,
                                    "content": file_chunk,
                                    "metadata": sanitize_metadata_for_chromadb(file_metadata),
                                })
                        except Exception:
                            # Silently skip files that cannot be processed.
                            continue

            else:
                metadata: dict[str, Any] = base_metadata.copy()

                # File-backed types must point at an existing file before
                # loading.
                if item_data_type in [DataType.PDF_FILE, DataType.TEXT_FILE, DataType.DOCX,
                                      DataType.CSV, DataType.JSON, DataType.XML, DataType.MDX]:
                    if not os.path.isfile(source_ref):
                        raise FileNotFoundError(f"File does not exist: {source_ref}")

                loader = item_data_type.get_loader()
                chunker = item_data_type.get_chunker()

                source_content = SourceContent(source_ref)
                loader_result: LoaderResult = loader.load(source_content)

                chunks = chunker.chunk(loader_result.content)

                for i, chunk in enumerate(chunks):
                    chunk_metadata: dict[str, Any] = metadata.copy()
                    chunk_metadata.update(loader_result.metadata)
                    chunk_metadata["data_type"] = str(item_data_type)
                    chunk_metadata["chunk_index"] = i
                    chunk_metadata["total_chunks"] = len(chunks)
                    chunk_metadata["source"] = source_ref

                    if isinstance(arg, dict):
                        chunk_metadata.update(arg.get("metadata", {}))

                    chunk_id = hashlib.sha256(
                        f"{loader_result.doc_id}_{i}_{chunk}".encode()
                    ).hexdigest()

                    documents.append({
                        "doc_id": chunk_id,
                        "content": chunk,
                        "metadata": sanitize_metadata_for_chromadb(chunk_metadata),
                    })

        if documents:
            self._client.add_documents(
                collection_name=self.collection_name,
                documents=documents,
            )
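

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the upstream module: it assumes
    # the global RAG client is configured and the (hypothetical) path exists.
    adapter = CrewAIRagAdapter(collection_name="demo")
    adapter.add("./docs/handbook.md", metadata={"source": "handbook"})
    print(adapter.query("What topics does the handbook cover?"))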