Squashed 'packages/tools/' content from commit 78317b9c

git-subtree-dir: packages/tools
git-subtree-split: 78317b9c127f18bd040c1d77e3c0840cdc9a5b38
Author: Greyson Lalonde
Date: 2025-09-12 21:58:02 -04:00
Commit: e16606672a
303 changed files with 49010 additions and 0 deletions

crewai_tools/rag/core.py (normal file, 232 lines)

@@ -0,0 +1,232 @@
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4

import chromadb
import litellm
from pydantic import BaseModel, Field, PrivateAttr

from crewai_tools.rag.base_loader import BaseLoader
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import compute_sha256
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.tools.rag.rag_tool import Adapter

logger = logging.getLogger(__name__)
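
# Thin wrapper around litellm's embedding API; any litellm-supported
# embedding model can be selected via `model`.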
class EmbeddingService:
    def __init__(self, model: str = "text-embedding-3-small", **kwargs):
        self.model = model
        self.kwargs = kwargs

    def embed_text(self, text: str) -> List[float]:
        try:
            response = litellm.embedding(
                model=self.model,
                input=[text],
                **self.kwargs,
            )
            return response.data[0]["embedding"]
        except Exception as e:
            logger.error(f"Error generating embedding: {e}")
            raise

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        if not texts:
            return []
        try:
            response = litellm.embedding(
                model=self.model,
                input=texts,
                **self.kwargs,
            )
            return [data["embedding"] for data in response.data]
        except Exception as e:
            logger.error(f"Error generating batch embeddings: {e}")
            raise
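
# One chunk of ingested content. `id` defaults to a random UUID, but
# RAG.add() overrides it with a SHA-256 of the chunk text so identical
# chunks deduplicate naturally.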
class Document(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    content: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    data_type: DataType = DataType.TEXT
    source: Optional[str] = None
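
# ChromaDB-backed knowledge base adapter: content is chunked and embedded
# once in add(), then retrieved by cosine similarity in query().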
class RAG(Adapter):
    collection_name: str = "crewai_knowledge_base"
    persist_directory: Optional[str] = None
    embedding_model: str = "text-embedding-3-large"
    summarize: bool = False
    top_k: int = 5
    embedding_config: Dict[str, Any] = Field(default_factory=dict)

    _client: Any = PrivateAttr()
    _collection: Any = PrivateAttr()
    _embedding_service: EmbeddingService = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        try:
            if self.persist_directory:
                self._client = chromadb.PersistentClient(path=self.persist_directory)
            else:
                self._client = chromadb.Client()
            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": "cosine", "description": "CrewAI Knowledge Base"},
            )
            self._embedding_service = EmbeddingService(model=self.embedding_model, **self.embedding_config)
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB: {e}")
            raise
        super().model_post_init(__context)
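
    # Ingestion pipeline: resolve the data type, load the source, dedupe by
    # content hash (doc_id), chunk, embed in one batch, and store.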
    def add(
        self,
        content: str | Path,
        data_type: Optional[Union[str, DataType]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        loader: Optional[BaseLoader] = None,
        chunker: Optional[BaseChunker] = None,
        **kwargs: Any,
    ) -> None:
        source_content = SourceContent(content)
        data_type = self._get_data_type(data_type=data_type, content=source_content)
        if not loader:
            loader = data_type.get_loader()
        if not chunker:
            chunker = data_type.get_chunker()

        loader_result = loader.load(source_content)
        doc_id = loader_result.doc_id

        existing_doc = self._collection.get(where={"source": source_content.source_ref}, limit=1)
        existing_metadatas = existing_doc.get("metadatas") or []
        existing_doc_id = existing_metadatas[0].get("doc_id") if existing_metadatas else None
        if existing_doc_id == doc_id:
            logger.warning(f"Document with source {loader_result.source} already exists")
            return
        # A document with the same source ref exists but its content has
        # changed; delete the stale version before re-adding.
        if existing_doc_id and existing_doc_id != doc_id:
            logger.warning(f"Deleting old document with doc_id {existing_doc_id}")
            self._collection.delete(where={"doc_id": existing_doc_id})

        documents = []
        chunks = chunker.chunk(loader_result.content)
        for i, chunk in enumerate(chunks):
            doc_metadata = (metadata or {}).copy()
            doc_metadata["chunk_index"] = i
            documents.append(Document(
                id=compute_sha256(chunk),
                content=chunk,
                metadata=doc_metadata,
                data_type=data_type,
                source=loader_result.source,
            ))
        if not documents:
            logger.warning("No documents to add")
            return

        contents = [doc.content for doc in documents]
        try:
            embeddings = self._embedding_service.embed_batch(contents)
        except Exception as e:
            logger.error(f"Failed to generate embeddings: {e}")
            return

        ids = [doc.id for doc in documents]
        metadatas = []
        for doc in documents:
            doc_metadata = doc.metadata.copy()
            doc_metadata.update({
                "data_type": doc.data_type.value,
                "source": doc.source,
                "doc_id": doc_id,
            })
            metadatas.append(doc_metadata)

        try:
            self._collection.add(
                ids=ids,
                embeddings=embeddings,
                documents=contents,
                metadatas=metadatas,
            )
            logger.info(f"Added {len(documents)} documents to knowledge base")
        except Exception as e:
            logger.error(f"Failed to add documents to ChromaDB: {e}")
    def query(self, question: str, where: Optional[Dict[str, Any]] = None) -> str:
        try:
            question_embedding = self._embedding_service.embed_text(question)
            results = self._collection.query(
                query_embeddings=[question_embedding],
                n_results=self.top_k,
                where=where,
                include=["documents", "metadatas", "distances"],
            )
            if not results or not results.get("documents") or not results["documents"][0]:
                return "No relevant content found."

            documents = results["documents"][0]
            metadatas = results.get("metadatas", [None])[0] or []
            distances = results.get("distances", [None])[0] or []

            # Return sources with relevance scores
            formatted_results = []
            for i, doc in enumerate(documents):
                metadata = metadatas[i] if i < len(metadatas) else {}
                distance = distances[i] if i < len(distances) else 1.0
                source = metadata.get("source", "unknown") if metadata else "unknown"
                score = 1 - distance if distance is not None else 0  # Convert distance to similarity
                formatted_results.append(f"[Source: {source}, Relevance: {score:.3f}]\n{doc}")
            return "\n\n".join(formatted_results)
        except Exception as e:
            logger.error(f"Query failed: {e}")
            return f"Error querying knowledge base: {e}"
    def delete_collection(self) -> None:
        try:
            self._client.delete_collection(self.collection_name)
            logger.info(f"Deleted collection: {self.collection_name}")
        except Exception as e:
            logger.error(f"Failed to delete collection: {e}")

    def get_collection_info(self) -> Dict[str, Any]:
        try:
            count = self._collection.count()
            return {
                "name": self.collection_name,
                "count": count,
                "embedding_model": self.embedding_model,
            }
        except Exception as e:
            logger.error(f"Failed to get collection info: {e}")
            return {"error": str(e)}
    def _get_data_type(self, content: SourceContent, data_type: str | DataType | None = None) -> DataType:
        # Prefer an explicit data type; fall back to what the content detects.
        if isinstance(data_type, DataType):
            return data_type
        if isinstance(data_type, str):
            try:
                return DataType(data_type)
            except ValueError:
                logger.warning(f"Unknown data type {data_type!r}; using detected type instead")
        return content.data_type
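
A minimal usage sketch of the adapter above. The collection name, file path,
and question are hypothetical, and the default embedding model requires the
corresponding provider credentials (e.g. an OpenAI key) to be configured for
litellm:

from crewai_tools.rag.core import RAG

rag = RAG(
    collection_name="project_docs",      # hypothetical example values
    persist_directory="./chroma_db",     # omit to use an in-memory client
    embedding_model="text-embedding-3-small",
    top_k=3,
)

# Loader and chunker are inferred from the detected data type.
rag.add("./handbook.md")

print(rag.query("What does the handbook say about refunds?"))
print(rag.get_collection_info())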