Files
crewAI/packages/tools/crewai_tools/rag/core.py

233 lines
8.1 KiB
Python

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4
import chromadb
import litellm
from pydantic import BaseModel, Field, PrivateAttr
from crewai_tools.tools.rag.rag_tool import Adapter
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.rag.misc import compute_sha256
# Module-level logger, named after this module via ``__name__``; used by the
# classes below to report embedding, storage, and query failures.
logger = logging.getLogger(__name__)
class EmbeddingService:
    """Generate embedding vectors through ``litellm.embedding``.

    The model name and any extra keyword arguments supplied at construction
    time are forwarded unchanged on every embedding request.
    """

    def __init__(self, model: str = "text-embedding-3-small", **kwargs):
        self.model = model
        self.kwargs = kwargs

    def embed_text(self, text: str) -> List[float]:
        """Return the embedding vector for a single piece of text.

        The text is sent as a one-element batch. Any failure is logged and
        then re-raised to the caller.
        """
        try:
            reply = litellm.embedding(
                model=self.model,
                input=[text],
                **self.kwargs
            )
            first_item = reply.data[0]
            return first_item['embedding']
        except Exception as exc:
            logger.error(f"Error generating embedding: {exc}")
            raise

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``texts``.

        An empty (or otherwise falsy) ``texts`` yields ``[]`` without
        issuing a request. Any failure is logged and then re-raised.
        """
        if not texts:
            return []

        try:
            reply = litellm.embedding(
                model=self.model,
                input=texts,
                **self.kwargs
            )
            vectors = []
            for item in reply.data:
                vectors.append(item['embedding'])
            return vectors
        except Exception as exc:
            logger.error(f"Error generating batch embeddings: {exc}")
            raise
class Document(BaseModel):
    """One piece of content plus the metadata stored alongside it."""

    # Identifier of the document; a random UUID4 string when not supplied.
    id: str = Field(default_factory=lambda: f"{uuid4()}")
    # The text of the document.
    content: str
    # Free-form metadata; every instance gets its own empty dict by default.
    metadata: Dict[str, Any] = Field(default_factory=dict)
    # Kind of content held; defaults to plain text.
    data_type: DataType = DataType.TEXT
    # Optional reference to where the content came from.
    source: Union[str, None] = None
class RAG(Adapter):
    """Knowledge base backed by a ChromaDB collection and litellm embeddings.

    Content passed to :meth:`add` is loaded, chunked, embedded, and stored in
    the collection; :meth:`query` embeds a question and returns the closest
    stored chunks formatted as text.
    """

    # Name of the ChromaDB collection that is created or re-opened.
    collection_name: str = "crewai_knowledge_base"
    # Directory for a persistent client; an in-memory client is used when unset.
    persist_directory: Optional[str] = None
    # Model name handed to the embedding service.
    embedding_model: str = "text-embedding-3-large"
    # NOTE(review): not read anywhere in this class; presumably consumed elsewhere.
    summarize: bool = False
    # Number of results requested from the collection per query.
    top_k: int = 5
    # Extra keyword arguments forwarded to the embedding service.
    embedding_config: Dict[str, Any] = Field(default_factory=dict)

    _client: Any = PrivateAttr()
    _collection: Any = PrivateAttr()
    _embedding_service: EmbeddingService = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        """Create the ChromaDB client, the collection, and the embedding service.

        Raises:
            Exception: Any initialization failure is logged and re-raised.
        """
        try:
            if self.persist_directory:
                self._client = chromadb.PersistentClient(path=self.persist_directory)
            else:
                self._client = chromadb.Client()

            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": "cosine", "description": "CrewAI Knowledge Base"},
            )
            self._embedding_service = EmbeddingService(
                model=self.embedding_model, **self.embedding_config
            )
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB: {e}")
            raise

        super().model_post_init(__context)

    def add(
        self,
        content: str | Path,
        data_type: Optional[Union[str, DataType]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        loader: Optional[BaseLoader] = None,
        chunker: Optional[BaseChunker] = None,
        **kwargs: Any
    ) -> None:
        """Load, chunk, embed, and store ``content`` in the collection.

        If the source is already stored with the same ``doc_id`` nothing is
        done. If it is stored with a different ``doc_id`` the old chunks are
        replaced, but only after the new chunks have been embedded
        successfully, so an embedding failure never leaves the source without
        any stored version.

        Args:
            content: Raw content or a path; wrapped in ``SourceContent``.
            data_type: Explicit data type (enum member or its string value);
                when omitted or invalid the type reported by the content is used.
            metadata: Extra metadata copied onto every stored chunk.
            loader: Loader override; defaults to the data type's loader.
            chunker: Chunker override; defaults to the data type's chunker.
            **kwargs: Accepted for compatibility; not used here.

        Embedding and storage failures are logged rather than raised.
        """
        source_content = SourceContent(content)
        data_type = self._get_data_type(data_type=data_type, content=source_content)

        if not loader:
            loader = data_type.get_loader()
        if not chunker:
            chunker = data_type.get_chunker()

        loader_result = loader.load(source_content)
        doc_id = loader_result.doc_id

        # NOTE(review): the lookup filters on ``source_ref`` while chunks are
        # stored with ``loader_result.source`` - confirm these always match.
        existing_doc_id = self._find_existing_doc_id(source_content.source_ref)
        if existing_doc_id is not None and existing_doc_id == doc_id:
            logger.warning(f"Document with source {loader_result.source} already exists")
            return

        documents = self._build_documents(
            chunks=chunker.chunk(loader_result.content),
            doc_id=doc_id,
            metadata=metadata,
            data_type=data_type,
            source=loader_result.source,
        )
        contents = [doc.content for doc in documents]

        # Embed before touching stored data so a failure here keeps the
        # previously stored version of the source intact.
        embeddings: List[List[float]] = []
        if documents:
            try:
                embeddings = self._embedding_service.embed_batch(contents)
            except Exception as e:
                logger.error(f"Failed to generate embeddings: {e}")
                return

        # Same source but different content: drop the outdated chunks.
        if existing_doc_id:
            logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
            self._collection.delete(where={"doc_id": existing_doc_id})

        if not documents:
            logger.warning("No documents to add")
            return

        metadatas = [
            {
                **doc.metadata,
                "data_type": doc.data_type.value,
                "source": doc.source,
                "doc_id": doc_id,
            }
            for doc in documents
        ]

        try:
            self._collection.add(
                ids=[doc.id for doc in documents],
                embeddings=embeddings,
                documents=contents,
                metadatas=metadatas,
            )
            logger.info(f"Added {len(documents)} documents to knowledge base")
        except Exception as e:
            logger.error(f"Failed to add documents to ChromaDB: {e}")

    def _find_existing_doc_id(self, source_ref: Any) -> Optional[str]:
        """Return the ``doc_id`` stored for ``source_ref``, or ``None`` if absent."""
        existing = self._collection.get(where={"source": source_ref}, limit=1)
        stored_metadatas = existing["metadatas"] if existing else None
        if not stored_metadatas:
            return None
        first = stored_metadatas[0]
        # Records without metadata or without a ``doc_id`` key count as "not found".
        return first.get("doc_id") if first else None

    def _build_documents(
        self,
        chunks: Any,
        doc_id: Any,
        metadata: Optional[Dict[str, Any]],
        data_type: DataType,
        source: Optional[str],
    ) -> List[Document]:
        """Wrap each chunk in a ``Document`` carrying its index and a unique id."""
        documents = []
        for i, chunk in enumerate(chunks):
            doc_metadata = (metadata or {}).copy()
            doc_metadata['chunk_index'] = i
            documents.append(Document(
                # Hash doc_id and position together with the text so that
                # identical chunks (within one document or across sources)
                # do not end up sharing an id.
                id=compute_sha256(f"{doc_id}:{i}:{chunk}"),
                content=chunk,
                metadata=doc_metadata,
                data_type=data_type,
                source=source,
            ))
        return documents

    def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str:
        """Return the ``top_k`` stored chunks closest to ``question`` as text.

        Each result is rendered as a ``[Source: ..., Relevance: ...]`` header
        followed by the chunk; results are separated by blank lines.

        Args:
            question: Text to embed and search for.
            where: Optional metadata filter passed to the collection query.

        Returns:
            The formatted results, ``"No relevant content found."`` when the
            query yields nothing, or an error message if anything fails.
        """
        try:
            question_embedding = self._embedding_service.embed_text(question)
            results = self._collection.query(
                query_embeddings=[question_embedding],
                n_results=self.top_k,
                where=where,
                include=["documents", "metadatas", "distances"],
            )

            if not results or not results.get("documents") or not results["documents"][0]:
                return "No relevant content found."

            documents = results["documents"][0]
            # ``or [None]`` also covers a key that is present but set to None.
            metadatas = (results.get("metadatas") or [None])[0] or []
            distances = (results.get("distances") or [None])[0] or []

            # Return sources with relevance scores
            formatted_results = []
            for i, doc in enumerate(documents):
                metadata = metadatas[i] if i < len(metadatas) else {}
                distance = distances[i] if i < len(distances) else 1.0
                source = metadata.get("source", "unknown") if metadata else "unknown"
                score = 1 - distance if distance is not None else 0  # Convert distance to similarity
                formatted_results.append(f"[Source: {source}, Relevance: {score:.3f}]\n{doc}")

            return "\n\n".join(formatted_results)
        except Exception as e:
            logger.error(f"Query failed: {e}")
            return f"Error querying knowledge base: {e}"

    def delete_collection(self) -> None:
        """Delete the collection from the client; failures are only logged."""
        try:
            self._client.delete_collection(self.collection_name)
            logger.info(f"Deleted collection: {self.collection_name}")
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")

    def get_collection_info(self) -> Dict[str, Any]:
        """Return the collection name, item count, and embedding model.

        On failure the error is logged and ``{"error": <message>}`` is returned.
        """
        try:
            count = self._collection.count()
            return {
                "name": self.collection_name,
                "count": count,
                "embedding_model": self.embedding_model
            }
        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            return {"error": str(e)}

    def _get_data_type(self, content: SourceContent, data_type: str | DataType | None = None) -> DataType:
        """Resolve the data type to use for ``content``.

        An explicit ``DataType`` member is returned as is and a string is
        converted to its member. ``None`` or an unrecognized string falls
        back to the type reported by ``content``.
        """
        if isinstance(data_type, DataType):
            return data_type
        if isinstance(data_type, str):
            try:
                return DataType(data_type)
            except ValueError:
                # Keep the best-effort fallback, but make the bad value visible.
                logger.warning(
                    f"Unknown data type {data_type!r}; falling back to detected data type"
                )
        return content.data_type