import os
from abc import ABC, abstractmethod
from typing import Any, cast

from crewai.rag.embeddings.factory import get_embedding_function
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, model_validator


class Adapter(BaseModel, ABC):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    def query(
        self,
        question: str,
        similarity_threshold: float | None = None,
        limit: int | None = None,
    ) -> str:
        """Query the knowledge base with a question and return the answer."""

    @abstractmethod
    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Add content to the knowledge base."""
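# Illustrative sketch, not part of the original module: a custom Adapter only
# needs to implement `query` and `add`. The class and field names below are
# hypothetical.
#
#     class KeywordAdapter(Adapter):
#         documents: list[str] = []
#
#         def query(
#             self,
#             question: str,
#             similarity_threshold: float | None = None,
#             limit: int | None = None,
#         ) -> str:
#             matches = [d for d in self.documents if question.lower() in d.lower()]
#             return "\n".join(matches[:limit] if limit else matches)
#
#         def add(self, *args: Any, **kwargs: Any) -> None:
#             self.documents.extend(str(arg) for arg in args)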
class RagTool(BaseTool):
    class _AdapterPlaceholder(Adapter):
        def query(
            self,
            question: str,
            similarity_threshold: float | None = None,
            limit: int | None = None,
        ) -> str:
            raise NotImplementedError

        def add(self, *args: Any, **kwargs: Any) -> None:
            raise NotImplementedError

    name: str = "Knowledge base"
    description: str = "A knowledge base that can be used to answer questions."
    summarize: bool = False
    similarity_threshold: float = 0.6
    limit: int = 5
    adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
    config: Any | None = None

    @model_validator(mode="after")
    def _set_default_adapter(self):
        # Replace the placeholder with the default CrewAI RAG adapter unless a
        # concrete adapter was supplied by the caller.
        if isinstance(self.adapter, RagTool._AdapterPlaceholder):
            from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter

            parsed_config = self._parse_config(self.config)

            self.adapter = CrewAIRagAdapter(
                collection_name="rag_tool_collection",
                summarize=self.summarize,
                similarity_threshold=self.similarity_threshold,
                limit=self.limit,
                config=parsed_config,
            )

        return self
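    # Illustrative note, not from the original source: constructing RagTool()
    # with no config triggers this validator and wires in a CrewAIRagAdapter
    # backed by the "rag_tool_collection" collection, e.g.:
    #
    #     tool = RagTool()
    #     assert not isinstance(tool.adapter, RagTool._AdapterPlaceholder)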
    def _parse_config(self, config: Any) -> Any:
        """Parse complex config format to extract provider-specific config.

        Raises:
            ValueError: If the config format is invalid or uses unsupported providers.
        """
        if config is None:
            return None

        if isinstance(config, dict) and "provider" in config:
            return config

        if isinstance(config, dict):
            if "vectordb" in config:
                vectordb_config = config["vectordb"]
                if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
                    provider = vectordb_config["provider"]
                    provider_config = vectordb_config.get("config", {})

                    supported_providers = ["chromadb", "qdrant"]
                    if provider not in supported_providers:
                        raise ValueError(
                            f"Unsupported vector database provider: '{provider}'. "
                            f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
                        )

                    embedding_config = config.get("embedding_model")
                    embedding_function = None
                    if embedding_config and isinstance(embedding_config, dict):
                        embedding_function = self._create_embedding_function(
                            embedding_config, provider
                        )

                    return self._create_provider_config(
                        provider, provider_config, embedding_function
                    )
                else:
                    return None
            else:
                embedding_config = config.get("embedding_model")
                embedding_function = None
                if embedding_config and isinstance(embedding_config, dict):
                    embedding_function = self._create_embedding_function(
                        embedding_config, "chromadb"
                    )

                return self._create_provider_config("chromadb", {}, embedding_function)
        return config
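    # Illustrative comment, not from the original source: _parse_config accepts
    # either a native dict that already carries a top-level "provider" key
    # (returned unchanged) or an embedchain-style dict like the one below, which
    # is translated into a ChromaDBConfig/QdrantConfig. The embedding model
    # name is an assumption for demonstration only.
    #
    #     {
    #         "vectordb": {"provider": "chromadb", "config": {}},
    #         "embedding_model": {
    #             "provider": "openai",
    #             "config": {"model": "text-embedding-3-small"},
    #         },
    #     }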
    @staticmethod
    def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
        """Create embedding function for the specified vector database provider."""
        embedding_provider = embedding_config.get("provider")
        embedding_model_config = embedding_config.get("config", {}).copy()

        if "model" in embedding_model_config:
            embedding_model_config["model_name"] = embedding_model_config.pop("model")

        factory_config = {"provider": embedding_provider, **embedding_model_config}

        if embedding_provider == "openai" and "api_key" not in factory_config:
            api_key = os.getenv("OPENAI_API_KEY")
            if api_key:
                factory_config["api_key"] = api_key

        print(f"Creating embedding function with config: {factory_config}")

        if provider == "chromadb":
            embedding_func = get_embedding_function(factory_config)
            print(f"Created embedding function: {embedding_func}")
            print(f"Embedding function type: {type(embedding_func)}")
            return embedding_func

        elif provider == "qdrant":
            chromadb_func = get_embedding_function(factory_config)

            def qdrant_embed_fn(text: str) -> list[float]:
                """Embed text using ChromaDB function and convert to list of floats for Qdrant.

                Args:
                    text: The input text to embed.

                Returns:
                    A list of floats representing the embedding.
                """
                embeddings = chromadb_func([text])
                return embeddings[0] if embeddings and len(embeddings) > 0 else []

            return cast(Any, qdrant_embed_fn)

        return None
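    # Illustrative comment, not from the original source: given the
    # embedding_model block shown above, the config handed to
    # get_embedding_function would look roughly like
    #
    #     {"provider": "openai", "model_name": "text-embedding-3-small",
    #      "api_key": "<value of OPENAI_API_KEY>"}
    #
    # because "model" is renamed to "model_name" and the OpenAI key is pulled
    # from the environment when not supplied explicitly.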
    @staticmethod
    def _create_provider_config(
        provider: str, provider_config: dict, embedding_function: Any
    ) -> Any:
        """Create proper provider config object."""
        if provider == "chromadb":
            from crewai.rag.chromadb.config import ChromaDBConfig

            config_kwargs = {}
            if embedding_function:
                config_kwargs["embedding_function"] = embedding_function

            config_kwargs.update(provider_config)

            return ChromaDBConfig(**config_kwargs)

        elif provider == "qdrant":
            from crewai.rag.qdrant.config import QdrantConfig

            config_kwargs = {}
            if embedding_function:
                config_kwargs["embedding_function"] = embedding_function

            config_kwargs.update(provider_config)

            return QdrantConfig(**config_kwargs)

        return None

    def add(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Add content to the knowledge base via the configured adapter."""
        self.adapter.add(*args, **kwargs)

    def _run(
        self,
        query: str,
        similarity_threshold: float | None = None,
        limit: int | None = None,
    ) -> str:
        """Query the knowledge base, falling back to the tool's default threshold and limit."""
        threshold = (
            similarity_threshold
            if similarity_threshold is not None
            else self.similarity_threshold
        )
        result_limit = limit if limit is not None else self.limit
        return f"Relevant Content:\n{self.adapter.query(query, similarity_threshold=threshold, limit=result_limit)}"
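# Illustrative end-to-end usage sketch, not part of the original module. What
# `add` accepts (raw text, file paths, URLs) is determined by the configured
# adapter, so the arguments below are assumptions for demonstration only.
#
#     tool = RagTool()  # validator swaps in the default CrewAIRagAdapter
#     tool.add("CrewAI is a framework for orchestrating role-playing AI agents.")
#     print(tool._run(query="What is CrewAI?", limit=3))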