* initial knowledge

* WIP

* Adding core knowledge sources

* Improve types and better support for file paths

* added additional sources

* fix linting

* update yaml to include optional deps

* adding in lorenze feedback

* ensure embeddings are persisted

* improvements all around Knowledge class

* return this

* properly reset memory

* properly reset memory+knowledge

* consolodation and improvements

* linted

* cleanup rm unused embedder

* fix test

* fix duplicate

* generating cassettes for knowledge test

* updated default embedder

* None embedder to use default on pipeline cloning

* improvements

* fixed text_file_knowledge

* mypysrc fixes

* type check fixes

* added extra cassette

* just mocks

* linted

* mock knowledge query to not spin up db

* linted

* verbose run

* put a flag

* fix

* adding docs

* better docs

* improvements from review

* more docs

* linted

* rm print

* more fixes

* clearer docs

* added docstrings and type hints for cli

---------

Co-authored-by: João Moura <joaomdmoura@gmail.com>
Co-authored-by: Lorenze Jay <lorenzejaytech@gmail.com>
This commit is contained in:
Brandon Hancock (bhancock_ai)
2024-11-20 18:40:08 -05:00
committed by GitHub
parent fde1ee45f9
commit 14a36d3f5e
37 changed files with 2302 additions and 266 deletions

View File

@@ -1,7 +1,9 @@
import warnings
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.pipeline import Pipeline
from crewai.process import Process
@@ -15,4 +17,14 @@ warnings.filterwarnings(
module="pydantic.main",
)
__version__ = "0.80.0"
__all__ = ["Agent", "Crew", "Process", "Task", "Pipeline", "Router", "LLM", "Flow"]
__all__ = [
"Agent",
"Crew",
"Process",
"Task",
"Pipeline",
"Router",
"LLM",
"Flow",
"Knowledge",
]

View File

@@ -11,8 +11,8 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.cli.constants import ENV_VARS
from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools import BaseTool
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.token_counter_callback import TokenCalcHandler
@@ -52,6 +52,7 @@ class Agent(BaseAgent):
role: The role of the agent.
goal: The objective of the agent.
backstory: The backstory of the agent.
knowledge: The knowledge base of the agent.
config: Dict representation of agent configuration.
llm: The language model that will run the agent.
function_calling_llm: The language model that will handle the tool calling for this agent, it overrides the crew function_calling_llm.
@@ -272,6 +273,18 @@ class Agent(BaseAgent):
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
# Integrate the knowledge base
if self.crew and self.crew.knowledge:
knowledge_snippets = self.crew.knowledge.query([task.prompt()])
valid_snippets = [
result["context"]
for result in knowledge_snippets
if result and result.get("context")
]
if valid_snippets:
formatted_knowledge = "\n".join(valid_snippets)
task_prompt += f"\n\nAdditional Information:\n{formatted_knowledge}"
tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task)

View File

@@ -136,6 +136,7 @@ def log_tasks_outputs() -> None:
@click.option("-l", "--long", is_flag=True, help="Reset LONG TERM memory")
@click.option("-s", "--short", is_flag=True, help="Reset SHORT TERM memory")
@click.option("-e", "--entities", is_flag=True, help="Reset ENTITIES memory")
@click.option("-kn", "--knowledge", is_flag=True, help="Reset KNOWLEDGE storage")
@click.option(
"-k",
"--kickoff-outputs",
@@ -143,17 +144,24 @@ def log_tasks_outputs() -> None:
help="Reset LATEST KICKOFF TASK OUTPUTS",
)
@click.option("-a", "--all", is_flag=True, help="Reset ALL memories")
def reset_memories(long, short, entities, kickoff_outputs, all):
def reset_memories(
long: bool,
short: bool,
entities: bool,
knowledge: bool,
kickoff_outputs: bool,
all: bool,
) -> None:
"""
Reset the crew memories (long, short, entity, latest_crew_kickoff_ouputs). This will delete all the data saved.
"""
try:
if not all and not (long or short or entities or kickoff_outputs):
if not all and not (long or short or entities or knowledge or kickoff_outputs):
click.echo(
"Please specify at least one memory type to reset using the appropriate flags."
)
return
reset_memories_command(long, short, entities, kickoff_outputs, all)
reset_memories_command(long, short, entities, knowledge, kickoff_outputs, all)
except Exception as e:
click.echo(f"An error occurred while resetting memories: {e}", err=True)

View File

@@ -5,9 +5,17 @@ from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
def reset_memories_command(
long,
short,
entity,
knowledge,
kickoff_outputs,
all,
) -> None:
"""
Reset the crew memories.
@@ -17,6 +25,7 @@ def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
entity (bool): Whether to reset the entity memory.
kickoff_outputs (bool): Whether to reset the latest kickoff task outputs.
all (bool): Whether to reset all memories.
knowledge (bool): Whether to reset the knowledge.
"""
try:
@@ -25,6 +34,7 @@ def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
EntityMemory().reset()
LongTermMemory().reset()
TaskOutputStorageHandler().reset()
KnowledgeStorage().reset()
click.echo("All memories have been reset.")
else:
if long:
@@ -40,6 +50,9 @@ def reset_memories_command(long, short, entity, kickoff_outputs, all) -> None:
if kickoff_outputs:
TaskOutputStorageHandler().reset()
click.echo("Latest Kickoff outputs stored has been reset.")
if knowledge:
KnowledgeStorage().reset()
click.echo("Knowledge has been reset.")
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True)

View File

@@ -27,6 +27,7 @@ from crewai.llm import LLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.knowledge.knowledge import Knowledge
from crewai.memory.user.user_memory import UserMemory
from crewai.process import Process
from crewai.task import Task
@@ -201,6 +202,10 @@ class Crew(BaseModel):
default=[],
description="List of execution logs for tasks",
)
knowledge: Optional[Dict[str, Any]] = Field(
default=None, description="Knowledge for the crew. Add knowledge sources to the knowledge object."
)
@field_validator("id", mode="before")
@classmethod
@@ -275,6 +280,15 @@ class Crew(BaseModel):
self._user_memory = None
return self
@model_validator(mode="after")
def create_crew_knowledge(self) -> "Crew":
if self.knowledge:
try:
self.knowledge = Knowledge(**self.knowledge) if isinstance(self.knowledge, dict) else self.knowledge
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid knowledge configuration: {str(e)}")
return self
@model_validator(mode="after")
def check_manager_llm(self):
"""Validates that the language model is set when using hierarchical process."""

View File

View File

@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from typing import List
import numpy as np
class BaseEmbedder(ABC):
"""
Abstract base class for text embedding models
"""
@abstractmethod
def embed_chunks(self, chunks: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of text chunks
Args:
chunks: List of text chunks to embed
Returns:
Array of embeddings
"""
pass
@abstractmethod
def embed_texts(self, texts: List[str]) -> np.ndarray:
"""
Generate embeddings for a list of texts
Args:
texts: List of texts to embed
Returns:
Array of embeddings
"""
pass
@abstractmethod
def embed_text(self, text: str) -> np.ndarray:
"""
Generate embedding for a single text
Args:
text: Text to embed
Returns:
Embedding array
"""
pass
@property
@abstractmethod
def dimension(self) -> int:
"""Get the dimension of the embeddings"""
pass

View File

@@ -0,0 +1,93 @@
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
from .base_embedder import BaseEmbedder
try:
from fastembed_gpu import TextEmbedding # type: ignore
FASTEMBED_AVAILABLE = True
except ImportError:
try:
from fastembed import TextEmbedding
FASTEMBED_AVAILABLE = True
except ImportError:
FASTEMBED_AVAILABLE = False
class FastEmbed(BaseEmbedder):
"""
A wrapper class for text embedding models using FastEmbed
"""
def __init__(
self,
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[Union[str, Path]] = None,
):
"""
Initialize the embedding model
Args:
model_name: Name of the model to use
cache_dir: Directory to cache the model
gpu: Whether to use GPU acceleration
"""
if not FASTEMBED_AVAILABLE:
raise ImportError(
"FastEmbed is not installed. Please install it with: "
"uv pip install fastembed or uv pip install fastembed-gpu for GPU support"
)
self.model = TextEmbedding(
model_name=model_name,
cache_dir=str(cache_dir) if cache_dir else None,
)
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
"""
Generate embeddings for a list of text chunks
Args:
chunks: List of text chunks to embed
Returns:
List of embeddings
"""
embeddings = list(self.model.embed(chunks))
return embeddings
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
"""
Generate embeddings for a list of texts
Args:
texts: List of texts to embed
Returns:
List of embeddings
"""
embeddings = list(self.model.embed(texts))
return embeddings
def embed_text(self, text: str) -> np.ndarray:
"""
Generate embedding for a single text
Args:
text: Text to embed
Returns:
Embedding array
"""
return self.embed_texts([text])[0]
@property
def dimension(self) -> int:
"""Get the dimension of the embeddings"""
# Generate a test embedding to get dimensions
test_embed = self.embed_text("test")
return len(test_embed)

View File

@@ -0,0 +1,54 @@
import os
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.utilities.logger import Logger
from crewai.utilities.constants import DEFAULT_SCORE_THRESHOLD
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
embedder_config: Optional[Dict[str, Any]] = None
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
embedder_config: Optional[Dict[str, Any]] = None
def __init__(self, embedder_config: Optional[Dict[str, Any]] = None, **data):
super().__init__(**data)
self.storage = KnowledgeStorage(embedder_config=embedder_config or None)
try:
for source in self.sources:
source.add()
except Exception as e:
Logger(verbose=True).log(
"warning",
f"Failed to init knowledge: {e}",
color="yellow",
)
def query(
self, query: List[str], limit: int = 3, preference: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
"""
results = self.storage.search(
query,
limit,
filter={"preference": preference} if preference else None,
score_threshold=DEFAULT_SCORE_THRESHOLD,
)
return results

View File

View File

@@ -0,0 +1,36 @@
from pathlib import Path
from typing import Union, List
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from typing import Dict, Any
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
class BaseFileKnowledgeSource(BaseKnowledgeSource):
"""Base class for knowledge sources that load content from files."""
file_path: Union[Path, List[Path]] = Field(...)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
def model_post_init(self, _):
"""Post-initialization method to load content."""
self.content = self.load_content()
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses."""
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
for path in paths:
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
if not path.is_file():
raise ValueError(f"Path is not a file: {path}")
return {}
def save_documents(self, metadata: Dict[str, Any]):
"""Save the documents to the storage."""
chunk_metadatas = [metadata.copy() for _ in self.chunks]
self.storage.save(self.chunks, chunk_metadatas)

View File

@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import numpy as np
from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
class BaseKnowledgeSource(BaseModel, ABC):
"""Abstract base class for knowledge sources."""
chunk_size: int = 4000
chunk_overlap: int = 200
chunks: List[str] = Field(default_factory=list)
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
metadata: Dict[str, Any] = Field(default_factory=dict)
@abstractmethod
def load_content(self) -> Dict[Any, str]:
"""Load and preprocess content from the source."""
pass
@abstractmethod
def add(self) -> None:
"""Process content, chunk it, compute embeddings, and save them."""
pass
def get_embeddings(self) -> List[np.ndarray]:
"""Return the list of embeddings for the chunks."""
return self.chunk_embeddings
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]
def save_documents(self, metadata: Dict[str, Any]):
"""
Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
"""
self.storage.save(self.chunks, metadata)

View File

@@ -0,0 +1,44 @@
import csv
from typing import Dict, List
from pathlib import Path
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
class CSVKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries CSV file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess CSV file content."""
super().load_content() # Validate the file path
file_path = (
self.file_path[0] if isinstance(self.file_path, list) else self.file_path
)
file_path = Path(file_path) if isinstance(file_path, str) else file_path
with open(file_path, "r", encoding="utf-8") as csvfile:
reader = csv.reader(csvfile)
content = ""
for row in reader:
content += " ".join(row) + "\n"
return {file_path: content}
def add(self) -> None:
"""
Add CSV file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
content_str = (
str(self.content) if isinstance(self.content, dict) else self.content
)
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
self.save_documents(metadata=self.metadata)
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]

View File

@@ -0,0 +1,56 @@
from typing import Dict, List
from pathlib import Path
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
class ExcelKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries Excel file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess Excel file content."""
super().load_content() # Validate the file path
pd = self._import_dependencies()
if isinstance(self.file_path, list):
file_path = self.file_path[0]
else:
file_path = self.file_path
df = pd.read_excel(file_path)
content = df.to_csv(index=False)
return {file_path: content}
def _import_dependencies(self):
"""Dynamically import dependencies."""
try:
import openpyxl # noqa
import pandas as pd
return pd
except ImportError as e:
missing_package = str(e).split()[-1]
raise ImportError(
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
)
def add(self) -> None:
"""
Add Excel file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
# Convert dictionary values to a single string if content is a dictionary
if isinstance(self.content, dict):
content_str = "\n".join(str(value) for value in self.content.values())
else:
content_str = str(self.content)
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
self.save_documents(metadata=self.metadata)
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]

View File

@@ -0,0 +1,54 @@
import json
from typing import Any, Dict, List
from pathlib import Path
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
class JSONKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries JSON file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess JSON file content."""
super().load_content() # Validate the file path
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
content: Dict[Path, str] = {}
for path in paths:
with open(path, "r", encoding="utf-8") as json_file:
data = json.load(json_file)
content[path] = self._json_to_text(data)
return content
def _json_to_text(self, data: Any, level: int = 0) -> str:
"""Recursively convert JSON data to a text representation."""
text = ""
indent = " " * level
if isinstance(data, dict):
for key, value in data.items():
text += f"{indent}{key}: {self._json_to_text(value, level + 1)}\n"
elif isinstance(data, list):
for item in data:
text += f"{indent}- {self._json_to_text(item, level + 1)}\n"
else:
text += f"{str(data)}"
return text
def add(self) -> None:
"""
Add JSON file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
content_str = (
str(self.content) if isinstance(self.content, dict) else self.content
)
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
self.save_documents(metadata=self.metadata)
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]

View File

@@ -0,0 +1,54 @@
from typing import List, Dict
from pathlib import Path
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
class PDFKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries PDF file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess PDF file content."""
super().load_content() # Validate the file paths
pdfplumber = self._import_pdfplumber()
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
content = {}
for path in paths:
text = ""
with pdfplumber.open(path) as pdf:
for page in pdf.pages:
page_text = page.extract_text()
if page_text:
text += page_text + "\n"
content[path] = text
return content
def _import_pdfplumber(self):
"""Dynamically import pdfplumber."""
try:
import pdfplumber
return pdfplumber
except ImportError:
raise ImportError(
"pdfplumber is not installed. Please install it with: pip install pdfplumber"
)
def add(self) -> None:
"""
Add PDF file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for _, text in self.content.items():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self.save_documents(metadata=self.metadata)
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]

View File

@@ -0,0 +1,33 @@
from typing import List
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
class StringKnowledgeSource(BaseKnowledgeSource):
"""A knowledge source that stores and queries plain text content using embeddings."""
content: str = Field(...)
def model_post_init(self, _):
"""Post-initialization method to validate content."""
self.load_content()
def load_content(self):
"""Validate string content."""
if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content")
def add(self) -> None:
"""Add string content to the knowledge source, chunk it, compute embeddings, and save them."""
new_chunks = self._chunk_text(self.content)
self.chunks.extend(new_chunks)
self.save_documents(metadata=self.metadata)
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]

View File

@@ -0,0 +1,35 @@
from typing import Dict, List
from pathlib import Path
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries text file content using embeddings."""
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess text file content."""
super().load_content()
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
content = {}
for path in paths:
with path.open("r", encoding="utf-8") as f:
content[path] = f.read() # type: ignore
return content
def add(self) -> None:
"""
Add text file content to the knowledge source, chunk it, compute embeddings,
and save the embeddings.
"""
for _, text in self.content.items():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
self.save_documents(metadata=self.metadata)
def _chunk_text(self, text: str) -> List[str]:
"""Utility method to split text into chunks."""
return [
text[i : i + self.chunk_size]
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]

View File

View File

@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional
class BaseKnowledgeStorage(ABC):
"""Abstract base class for knowledge storage implementations."""
@abstractmethod
def search(
self,
query: List[str],
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
"""Search for documents in the knowledge base."""
pass
@abstractmethod
def save(
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
) -> None:
"""Save documents to the knowledge base."""
pass
@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
pass

View File

@@ -0,0 +1,132 @@
import contextlib
import io
import logging
import chromadb
import os
from crewai.utilities.paths import db_storage_path
from typing import Optional, List
from typing import Dict, Any
from crewai.utilities import EmbeddingConfigurator
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
import hashlib
@contextlib.contextmanager
def suppress_logging(
logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
level=logging.ERROR,
):
logger = logging.getLogger(logger_name)
original_level = logger.getEffectiveLevel()
logger.setLevel(level)
with (
contextlib.redirect_stdout(io.StringIO()),
contextlib.redirect_stderr(io.StringIO()),
contextlib.suppress(UserWarning),
):
yield
logger.setLevel(original_level)
class KnowledgeStorage(BaseKnowledgeStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
search efficiency.
"""
collection: Optional[chromadb.Collection] = None
def __init__(self, embedder_config: Optional[Dict[str, Any]] = None):
self._initialize_app(embedder_config or {})
def search(
self,
query: List[str],
limit: int = 3,
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Dict[str, Any]]:
with suppress_logging():
if self.collection:
fetched = self.collection.query(
query_texts=query,
n_results=limit,
where=filter,
)
results = []
for i in range(len(fetched["ids"][0])): # type: ignore
result = {
"id": fetched["ids"][0][i], # type: ignore
"metadata": fetched["metadatas"][0][i], # type: ignore
"context": fetched["documents"][0][i], # type: ignore
"score": fetched["distances"][0][i], # type: ignore
}
if result["score"] >= score_threshold: # type: ignore
results.append(result)
return results
else:
raise Exception("Collection not initialized")
def _initialize_app(self, embedder_config: Optional[Dict[str, Any]] = None):
import chromadb
from chromadb.config import Settings
self._set_embedder_config(embedder_config)
chroma_client = chromadb.PersistentClient(
path=f"{db_storage_path()}/knowledge",
settings=Settings(allow_reset=True),
)
self.app = chroma_client
try:
self.collection = self.app.get_or_create_collection(name="knowledge")
except Exception:
raise Exception("Failed to create or get collection")
def reset(self):
if self.app:
self.app.reset()
def save(
self, documents: List[str], metadata: Dict[str, Any] | List[Dict[str, Any]]
):
if self.collection:
metadatas = [metadata] if isinstance(metadata, dict) else metadata
ids = [
hashlib.sha256(doc.encode("utf-8")).hexdigest() for doc in documents
]
self.collection.upsert(
documents=documents,
metadatas=metadatas,
ids=ids,
)
else:
raise Exception("Collection not initialized")
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
def _set_embedder_config(
self, embedder_config: Optional[Dict[str, Any]] = None
) -> None:
"""Set the embedding configuration for the knowledge storage.
Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
"""
self.embedder_config = (
EmbeddingConfigurator().configure_embedder(embedder_config)
if embedder_config
else self._create_default_embedding_function()
)

View File

@@ -4,13 +4,12 @@ import logging
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from crewai.utilities import EmbeddingConfigurator
@contextlib.contextmanager
@@ -51,133 +50,8 @@ class RAGStorage(BaseRAGStorage):
self._initialize_app()
def _set_embedder_config(self):
if self.embedder_config is None:
self.embedder_config = self._create_default_embedding_function()
if isinstance(self.embedder_config, dict):
provider = self.embedder_config.get("provider")
config = self.embedder_config.get("config", {})
model_name = config.get("model")
if provider == "openai":
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
self.embedder_config = OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
)
elif provider == "azure":
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
self.embedder_config = OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
)
elif provider == "ollama":
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
self.embedder_config = OllamaEmbeddingFunction(
url=config.get("url", "http://localhost:11434/api/embeddings"),
model_name=model_name,
)
elif provider == "vertexai":
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
self.embedder_config = GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "google":
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
self.embedder_config = GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "cohere":
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
self.embedder_config = CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "bedrock":
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
self.embedder_config = AmazonBedrockEmbeddingFunction(
session=config.get("session"),
)
elif provider == "huggingface":
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
self.embedder_config = HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
elif provider == "watson":
try:
import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import (
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=config.get("api_key"), url=config.get("api_url")
),
project_id=config.get("project_id"),
)
try:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print("Error during Watson embedding:", e)
raise e
self.embedder_config = WatsonEmbeddingFunction()
else:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface, watson]"
)
else:
validate_embedding_function(self.embedder_config)
self.embedder_config = self.embedder_config
configurator = EmbeddingConfigurator()
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
import chromadb

View File

@@ -10,6 +10,7 @@ from .rpm_controller import RPMController
from .exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
from .embedding_configurator import EmbeddingConfigurator
__all__ = [
"Converter",
@@ -23,4 +24,5 @@ __all__ = [
"RPMController",
"YamlParser",
"LLMContextLengthExceededException",
"EmbeddingConfigurator",
]

View File

@@ -1,2 +1,3 @@
TRAINING_DATA_FILE = "training_data.pkl"
TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl"
DEFAULT_SCORE_THRESHOLD = 0.35

View File

@@ -0,0 +1,183 @@
import os
from typing import Any, Dict, cast
from chromadb import EmbeddingFunction, Documents, Embeddings
from chromadb.api.types import validate_embedding_function
class EmbeddingConfigurator:
def __init__(self):
self.embedding_functions = {
"openai": self._configure_openai,
"azure": self._configure_azure,
"ollama": self._configure_ollama,
"vertexai": self._configure_vertexai,
"google": self._configure_google,
"cohere": self._configure_cohere,
"bedrock": self._configure_bedrock,
"huggingface": self._configure_huggingface,
"watson": self._configure_watson,
}
def configure_embedder(
self,
embedder_config: Dict[str, Any] | None = None,
) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config."""
if embedder_config is None:
return self._create_default_embedding_function()
provider = embedder_config.get("provider")
config = embedder_config.get("config", {})
model_name = config.get("model")
if isinstance(provider, EmbeddingFunction):
try:
validate_embedding_function(provider)
return provider
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
if provider not in self.embedding_functions:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
)
return self.embedding_functions[provider](config, model_name)
@staticmethod
def _create_default_embedding_function():
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
@staticmethod
def _configure_openai(config, model_name):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
)
@staticmethod
def _configure_azure(config, model_name):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
)
@staticmethod
def _configure_ollama(config, model_name):
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
return OllamaEmbeddingFunction(
url=config.get("url", "http://localhost:11434/api/embeddings"),
model_name=model_name,
)
@staticmethod
def _configure_vertexai(config, model_name):
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
return GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
@staticmethod
def _configure_google(config, model_name):
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
return GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
@staticmethod
def _configure_cohere(config, model_name):
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
return CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
@staticmethod
def _configure_bedrock(config, model_name):
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
return AmazonBedrockEmbeddingFunction(
session=config.get("session"),
)
@staticmethod
def _configure_huggingface(config, model_name):
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
return HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
@staticmethod
def _configure_watson(config, model_name):
try:
import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=config.get("api_key"), url=config.get("api_url")
),
project_id=config.get("project_id"),
)
try:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print("Error during Watson embedding:", e)
raise e
return WatsonEmbeddingFunction()