"""Adapter for CrewAI's native RAG system."""

import hashlib
from pathlib import Path
from typing import Any, TypeAlias, TypedDict

from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from pydantic import PrivateAttr
from typing_extensions import Unpack

from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.tools.rag.rag_tool import Adapter

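# A ContentItem may be raw text, a filesystem path, or a pre-built document dict.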
ContentItem: TypeAlias = str | Path | dict[str, Any]


class AddDocumentParams(TypedDict, total=False):
    """Parameters for adding documents to the RAG system."""
    data_type: DataType
    metadata: dict[str, Any]
    website: str
    url: str
    file_path: str | Path
    github_url: str
    youtube_url: str
    directory_path: str | Path


class CrewAIRagAdapter(Adapter):
    """Adapter that uses CrewAI's native RAG system.

    Supports custom vector database configuration through the config parameter.
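
    Example:
        Minimal usage sketch; the collection name, path, and question are
        illustrative placeholders, and a configured RAG backend is assumed:

            adapter = CrewAIRagAdapter(collection_name="my_docs")
            adapter.add("/path/to/notes.txt")
            answer = adapter.query("What do the notes say?")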
    """

    collection_name: str = "default"
    summarize: bool = False
    similarity_threshold: float = 0.6
    limit: int = 5
    config: RagConfigType | None = None
    _client: BaseClient | None = PrivateAttr(default=None)

    def model_post_init(self, __context: Any) -> None:
        """Initialize the CrewAI RAG client after model initialization."""
        if self.config is not None:
            self._client = create_client(self.config)
        else:
            self._client = get_rag_client()
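        # Ensure the collection exists before any add or query call.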
        self._client.get_or_create_collection(collection_name=self.collection_name)

    def query(
        self,
        question: str,
        similarity_threshold: float | None = None,
        limit: int | None = None,
    ) -> str:
        """Query the knowledge base with a question.

        Args:
            question: The question to ask.
            similarity_threshold: Minimum similarity score for results (default: 0.6).
            limit: Maximum number of results to return (default: 5).

        Returns:
            Relevant content from the knowledge base.
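
        Example:
            Illustrative call; the question and overrides are placeholders:

                adapter.query("What is the refund policy?", limit=3)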
        """
        search_limit = limit if limit is not None else self.limit
        search_threshold = (
            similarity_threshold
            if similarity_threshold is not None
            else self.similarity_threshold
        )

        results: list[SearchResult] = self._client.search(
            collection_name=self.collection_name,
            query=question,
            limit=search_limit,
            score_threshold=search_threshold,
        )

        if not results:
            return "No relevant content found."

        contents: list[str] = []
        for result in results:
            content: str = result.get("content", "")
            if content:
                contents.append(content)

        return "\n\n".join(contents)

    def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
        """Add content to the knowledge base.

        This method handles various input types and converts them to documents
        for the vector database. It supports the data_type parameter for
        compatibility with existing tools.

        Args:
            *args: Content items to add (strings, paths, or document dicts).
            **kwargs: Additional parameters including data_type, metadata, etc.
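
        Example:
            Illustrative calls; the paths are placeholders:

                adapter.add("/docs/guide.pdf", data_type=DataType.PDF_FILE)
                adapter.add("/repo/src", data_type=DataType.DIRECTORY)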
        """
        import os

        from crewai_tools.rag.base_loader import LoaderResult
        from crewai_tools.rag.data_types import DataType, DataTypes
        from crewai_tools.rag.source_content import SourceContent

        documents: list[BaseRecord] = []
        data_type: DataType | None = kwargs.get("data_type")
        base_metadata: dict[str, Any] = kwargs.get("metadata", {})

        for arg in args:
            source_ref: str
            if isinstance(arg, dict):
                source_ref = str(arg.get("source", arg.get("content", "")))
            else:
                source_ref = str(arg)

            if not data_type:
                data_type = DataTypes.from_content(source_ref)
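            # Note: an inferred data_type persists across the loop, so the
            # first argument's type is reused for any later arguments.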

            if data_type == DataType.DIRECTORY:
                if not os.path.isdir(source_ref):
                    raise ValueError(f"Directory does not exist: {source_ref}")

                # Binary and non-text file extensions to skip.
                binary_extensions = {
                    ".pyc",
                    ".pyo",
                    ".png",
                    ".jpg",
                    ".jpeg",
                    ".gif",
                    ".bmp",
                    ".ico",
                    ".svg",
                    ".webp",
                    ".pdf",
                    ".zip",
                    ".tar",
                    ".gz",
                    ".bz2",
                    ".7z",
                    ".rar",
                    ".exe",
                    ".dll",
                    ".so",
                    ".dylib",
                    ".bin",
                    ".dat",
                    ".db",
                    ".sqlite",
                    ".class",
                    ".jar",
                    ".war",
                    ".ear",
                }

                for root, dirs, files in os.walk(source_ref):
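                    # Prune hidden directories in place so os.walk does not
                    # descend into them (os.walk honors mutations to dirs).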
                    dirs[:] = [d for d in dirs if not d.startswith(".")]

                    for filename in files:
                        if filename.startswith("."):
                            continue

                        # Skip binary files based on extension.
                        file_ext = os.path.splitext(filename)[1].lower()
                        if file_ext in binary_extensions:
                            continue

                        # Skip files inside __pycache__ directories.
                        if "__pycache__" in root:
                            continue

                        file_path: str = os.path.join(root, filename)
                        try:
                            file_data_type: DataType = DataTypes.from_content(file_path)
                            file_loader = file_data_type.get_loader()
                            file_chunker = file_data_type.get_chunker()

                            file_source = SourceContent(file_path)
                            file_result: LoaderResult = file_loader.load(file_source)

                            file_chunks = file_chunker.chunk(file_result.content)

                            for chunk_idx, file_chunk in enumerate(file_chunks):
                                file_metadata: dict[str, Any] = base_metadata.copy()
                                file_metadata.update(file_result.metadata)
                                file_metadata["data_type"] = str(file_data_type)
                                file_metadata["file_path"] = file_path
                                file_metadata["chunk_index"] = chunk_idx
                                file_metadata["total_chunks"] = len(file_chunks)

                                if isinstance(arg, dict):
                                    file_metadata.update(arg.get("metadata", {}))
                                chunk_id = hashlib.sha256(
                                    f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
                                ).hexdigest()

                                documents.append(
                                    {
                                        "doc_id": chunk_id,
                                        "content": file_chunk,
                                        "metadata": sanitize_metadata_for_chromadb(
                                            file_metadata
                                        ),
                                    }
                                )
                        except Exception:  # noqa: S112
                            # Silently skip files that can't be processed.
                            continue
            else:
                metadata: dict[str, Any] = base_metadata.copy()

                if data_type in [
                    DataType.PDF_FILE,
                    DataType.TEXT_FILE,
                    DataType.DOCX,
                    DataType.CSV,
                    DataType.JSON,
                    DataType.XML,
                    DataType.MDX,
                ]:
                    if not os.path.isfile(source_ref):
                        raise FileNotFoundError(f"File does not exist: {source_ref}")

                loader = data_type.get_loader()
                chunker = data_type.get_chunker()

                source_content = SourceContent(source_ref)
                loader_result: LoaderResult = loader.load(source_content)

                chunks = chunker.chunk(loader_result.content)

                for i, chunk in enumerate(chunks):
                    chunk_metadata: dict[str, Any] = metadata.copy()
                    chunk_metadata.update(loader_result.metadata)
                    chunk_metadata["data_type"] = str(data_type)
                    chunk_metadata["chunk_index"] = i
                    chunk_metadata["total_chunks"] = len(chunks)
                    chunk_metadata["source"] = source_ref

                    if isinstance(arg, dict):
                        chunk_metadata.update(arg.get("metadata", {}))
                    chunk_id = hashlib.sha256(
                        f"{loader_result.doc_id}_{i}_{chunk}".encode()
                    ).hexdigest()

                    documents.append(
                        {
                            "doc_id": chunk_id,
                            "content": chunk,
                            "metadata": sanitize_metadata_for_chromadb(chunk_metadata),
                        }
                    )

        if documents:
            self._client.add_documents(
                collection_name=self.collection_name, documents=documents
            )