Compare commits

...

14 Commits

Author SHA1 Message Date
Lorenze Jay
1589230833 Merge branch 'main' of github.com:crewAIInc/crewAI into fix/memory-embedder-config 2024-10-22 12:28:07 -07:00
Lorenze Jay
6d0251224e Merge branch 'fix/memory-embedder-config' of github.com:crewAIInc/crewAI into fix/memory-embedder-config 2024-10-22 12:27:50 -07:00
Lorenze Jay
e2f70cb53f updates to add more docs and correct imports with huggingface embedding server enabled 2024-10-22 12:27:47 -07:00
Brandon Hancock
0dd522ddff Merge branch 'main' into fix/memory-embedder-config 2024-10-21 19:41:45 -04:00
Lorenze Jay
5803b3fb69 fixed run types 2024-10-21 16:04:42 -07:00
Lorenze Jay
31c3082740 fixed docs 2024-10-21 14:29:11 -07:00
Lorenze Jay
21afc46c0d Merge branch 'main' of github.com:crewAIInc/crewAI into fix/memory-embedder-config 2024-10-21 14:24:35 -07:00
Lorenze Jay
78882c6de2 rm prints 2024-10-21 14:24:26 -07:00
Lorenze Jay
2786086974 fixes 2024-10-21 14:24:07 -07:00
Lorenze Jay
6b12ac9c0b Merge branch 'main' of github.com:crewAIInc/crewAI into fix/memory-embedder-config 2024-10-21 09:31:56 -07:00
Lorenze Jay
266ecff395 WIP: brandons notes 2024-10-21 09:31:39 -07:00
Lorenze Jay
34d748d18e raise error on unsupported provider 2024-10-21 08:37:42 -07:00
Lorenze Jay
79f527576b some fixes 2024-10-20 18:26:24 -07:00
Lorenze Jay
3fc83c624b ensure original embedding config works 2024-10-20 18:12:57 -07:00
3 changed files with 166 additions and 18 deletions

View File

@@ -105,9 +105,48 @@ my_crew = Crew(
process=Process.sequential,
memory=True,
verbose=True,
embedder=embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
embedder={
"provider": "openai",
"config": {
"model": 'text-embedding-3-small'
}
}
)
```
Alternatively, you can directly pass the OpenAIEmbeddingFunction to the embedder parameter.
Example:
```python Code
from crewai import Crew, Agent, Task, Process
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder=OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"),
)
```
### Using Ollama embeddings
```python Code
from crewai import Crew, Agent, Task, Process
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "ollama",
"config": {
"model": "mxbai-embed-large"
}
}
)
```
@@ -122,10 +161,13 @@ my_crew = Crew(
process=Process.sequential,
memory=True,
verbose=True,
embedder=embedding_functions.OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="text-embedding-ada-002"
)
embedder={
"provider": "google",
"config": {
"api_key": "<YOUR_API_KEY>",
"model_name": "<model_name>"
}
}
)
```
@@ -181,10 +223,32 @@ my_crew = Crew(
process=Process.sequential,
memory=True,
verbose=True,
embedder=embedding_functions.CohereEmbeddingFunction(
api_key=YOUR_API_KEY,
model_name="<model_name>"
)
embedder={
"provider": "cohere",
"config": {
"api_key": "YOUR_API_KEY",
"model_name": "<model_name>"
}
}
)
```
### Using HuggingFace embeddings
```python Code
from crewai import Crew, Agent, Task, Process
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "huggingface",
"config": {
"api_url": "<api_url>",
}
}
)
```

View File

@@ -16,7 +16,7 @@ class EntityMemory(Memory):
if storage
else RAGStorage(
type="entities",
allow_reset=False,
allow_reset=True,
embedder_config=embedder_config,
crew=crew,
)

View File

@@ -8,6 +8,9 @@ from typing import Any, Dict, List, Optional
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from chromadb.api import ClientAPI
from chromadb.api.types import validate_embedding_function
from chromadb import Documents, EmbeddingFunction, Embeddings
from typing import cast
@contextlib.contextmanager
@@ -41,16 +44,93 @@ class RAGStorage(BaseRAGStorage):
self.agents = agents
self.type = type
self.embedder_config = embedder_config or self._create_embedding_function()
self.allow_reset = allow_reset
self._initialize_app()
def _set_embedder_config(self):
import chromadb.utils.embedding_functions as embedding_functions
if self.embedder_config is None:
self.embedder_config = self._create_default_embedding_function()
if isinstance(self.embedder_config, dict):
provider = self.embedder_config.get("provider")
config = self.embedder_config.get("config", {})
model_name = config.get("model")
if provider == "openai":
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
)
elif provider == "azure":
self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
)
elif provider == "ollama":
from openai import OpenAI
class OllamaEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
client = OpenAI(
base_url="http://localhost:11434/v1",
api_key=config.get("api_key", "ollama"),
)
try:
response = client.embeddings.create(
input=input, model=model_name
)
embeddings = [item.embedding for item in response.data]
return cast(Embeddings, embeddings)
except Exception as e:
raise e
self.embedder_config = OllamaEmbeddingFunction()
elif provider == "vertexai":
self.embedder_config = (
embedding_functions.GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
)
elif provider == "google":
self.embedder_config = (
embedding_functions.GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
)
elif provider == "cohere":
self.embedder_config = embedding_functions.CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
elif provider == "huggingface":
self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
else:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
)
else:
validate_embedding_function(self.embedder_config) # type: ignore # used for validating embedder_config if defined a embedding function/class
self.embedder_config = self.embedder_config
def _initialize_app(self):
import chromadb
from chromadb.config import Settings
self._set_embedder_config()
chroma_client = chromadb.PersistentClient(
path=f"{db_storage_path()}/{self.type}/{self.agents}"
path=f"{db_storage_path()}/{self.type}/{self.agents}",
settings=Settings(allow_reset=self.allow_reset),
)
self.app = chroma_client
try:
@@ -122,11 +202,15 @@ class RAGStorage(BaseRAGStorage):
if self.app:
self.app.reset()
except Exception as e:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
if "attempt to write a readonly database" in str(e):
# Ignore this specific error
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_embedding_function(self):
def _create_default_embedding_function(self):
import chromadb.utils.embedding_functions as embedding_functions
return embedding_functions.OpenAIEmbeddingFunction(