Files
crewAI/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py
2025-10-02 11:06:38 -04:00

270 lines
9.8 KiB
Python

"""Adapter for CrewAI's native RAG system."""
import hashlib
from pathlib import Path
from typing import Any, TypeAlias, TypedDict
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from pydantic import PrivateAttr
from typing_extensions import Unpack
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.tools.rag.rag_tool import Adapter
ContentItem: TypeAlias = str | Path | dict[str, Any]
class AddDocumentParams(TypedDict, total=False):
"""Parameters for adding documents to the RAG system."""
data_type: DataType
metadata: dict[str, Any]
website: str
url: str
file_path: str | Path
github_url: str
youtube_url: str
directory_path: str | Path
class CrewAIRagAdapter(Adapter):
"""Adapter that uses CrewAI's native RAG system.
Supports custom vector database configuration through the config parameter.
"""
collection_name: str = "default"
summarize: bool = False
similarity_threshold: float = 0.6
limit: int = 5
config: RagConfigType | None = None
_client: BaseClient | None = PrivateAttr(default=None)
def model_post_init(self, __context: Any) -> None:
"""Initialize the CrewAI RAG client after model initialization."""
if self.config is not None:
self._client = create_client(self.config)
else:
self._client = get_rag_client()
self._client.get_or_create_collection(collection_name=self.collection_name)
def query(
self,
question: str,
similarity_threshold: float | None = None,
limit: int | None = None,
) -> str:
"""Query the knowledge base with a question.
Args:
question: The question to ask
similarity_threshold: Minimum similarity score for results (default: 0.6)
limit: Maximum number of results to return (default: 5)
Returns:
Relevant content from the knowledge base
"""
search_limit = limit if limit is not None else self.limit
search_threshold = (
similarity_threshold
if similarity_threshold is not None
else self.similarity_threshold
)
results: list[SearchResult] = self._client.search(
collection_name=self.collection_name,
query=question,
limit=search_limit,
score_threshold=search_threshold,
)
if not results:
return "No relevant content found."
contents: list[str] = []
for result in results:
content: str = result.get("content", "")
if content:
contents.append(content)
return "\n\n".join(contents)
def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
"""Add content to the knowledge base.
This method handles various input types and converts them to documents
for the vector database. It supports the data_type parameter for
compatibility with existing tools.
Args:
*args: Content items to add (strings, paths, or document dicts)
**kwargs: Additional parameters including data_type, metadata, etc.
"""
import os
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.data_types import DataType, DataTypes
from crewai_tools.rag.source_content import SourceContent
documents: list[BaseRecord] = []
data_type: DataType | None = kwargs.get("data_type")
base_metadata: dict[str, Any] = kwargs.get("metadata", {})
for arg in args:
source_ref: str
if isinstance(arg, dict):
source_ref = str(arg.get("source", arg.get("content", "")))
else:
source_ref = str(arg)
if not data_type:
data_type = DataTypes.from_content(source_ref)
if data_type == DataType.DIRECTORY:
if not os.path.isdir(source_ref):
raise ValueError(f"Directory does not exist: {source_ref}")
# Define binary and non-text file extensions to skip
binary_extensions = {
".pyc",
".pyo",
".png",
".jpg",
".jpeg",
".gif",
".bmp",
".ico",
".svg",
".webp",
".pdf",
".zip",
".tar",
".gz",
".bz2",
".7z",
".rar",
".exe",
".dll",
".so",
".dylib",
".bin",
".dat",
".db",
".sqlite",
".class",
".jar",
".war",
".ear",
}
for root, dirs, files in os.walk(source_ref):
dirs[:] = [d for d in dirs if not d.startswith(".")]
for filename in files:
if filename.startswith("."):
continue
# Skip binary files based on extension
file_ext = os.path.splitext(filename)[1].lower()
if file_ext in binary_extensions:
continue
# Skip __pycache__ directories
if "__pycache__" in root:
continue
file_path: str = os.path.join(root, filename)
try:
file_data_type: DataType = DataTypes.from_content(file_path)
file_loader = file_data_type.get_loader()
file_chunker = file_data_type.get_chunker()
file_source = SourceContent(file_path)
file_result: LoaderResult = file_loader.load(file_source)
file_chunks = file_chunker.chunk(file_result.content)
for chunk_idx, file_chunk in enumerate(file_chunks):
file_metadata: dict[str, Any] = base_metadata.copy()
file_metadata.update(file_result.metadata)
file_metadata["data_type"] = str(file_data_type)
file_metadata["file_path"] = file_path
file_metadata["chunk_index"] = chunk_idx
file_metadata["total_chunks"] = len(file_chunks)
if isinstance(arg, dict):
file_metadata.update(arg.get("metadata", {}))
chunk_id = hashlib.sha256(
f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
).hexdigest()
documents.append(
{
"doc_id": chunk_id,
"content": file_chunk,
"metadata": sanitize_metadata_for_chromadb(
file_metadata
),
}
)
except Exception: # noqa: S112
# Silently skip files that can't be processed
continue
else:
metadata: dict[str, Any] = base_metadata.copy()
if data_type in [
DataType.PDF_FILE,
DataType.TEXT_FILE,
DataType.DOCX,
DataType.CSV,
DataType.JSON,
DataType.XML,
DataType.MDX,
]:
if not os.path.isfile(source_ref):
raise FileNotFoundError(f"File does not exist: {source_ref}")
loader = data_type.get_loader()
chunker = data_type.get_chunker()
source_content = SourceContent(source_ref)
loader_result: LoaderResult = loader.load(source_content)
chunks = chunker.chunk(loader_result.content)
for i, chunk in enumerate(chunks):
chunk_metadata: dict[str, Any] = metadata.copy()
chunk_metadata.update(loader_result.metadata)
chunk_metadata["data_type"] = str(data_type)
chunk_metadata["chunk_index"] = i
chunk_metadata["total_chunks"] = len(chunks)
chunk_metadata["source"] = source_ref
if isinstance(arg, dict):
chunk_metadata.update(arg.get("metadata", {}))
chunk_id = hashlib.sha256(
f"{loader_result.doc_id}_{i}_{chunk}".encode()
).hexdigest()
documents.append(
{
"doc_id": chunk_id,
"content": chunk,
"metadata": sanitize_metadata_for_chromadb(chunk_metadata),
}
)
if documents:
self._client.add_documents(
collection_name=self.collection_name, documents=documents
)