mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-17 04:48:30 +00:00
Compare commits
14 Commits
bugfix/res
...
fix/memory
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1589230833 | ||
|
|
6d0251224e | ||
|
|
e2f70cb53f | ||
|
|
0dd522ddff | ||
|
|
5803b3fb69 | ||
|
|
31c3082740 | ||
|
|
21afc46c0d | ||
|
|
78882c6de2 | ||
|
|
2786086974 | ||
|
|
6b12ac9c0b | ||
|
|
266ecff395 | ||
|
|
34d748d18e | ||
|
|
79f527576b | ||
|
|
3fc83c624b |
@@ -105,9 +105,48 @@ my_crew = Crew(
|
|||||||
process=Process.sequential,
|
process=Process.sequential,
|
||||||
memory=True,
|
memory=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
embedder=embedding_functions.OpenAIEmbeddingFunction(
|
embedder={
|
||||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
"provider": "openai",
|
||||||
|
"config": {
|
||||||
|
"model": 'text-embedding-3-small'
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
```
|
||||||
|
Alternatively, you can directly pass the OpenAIEmbeddingFunction to the embedder parameter.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python Code
|
||||||
|
from crewai import Crew, Agent, Task, Process
|
||||||
|
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
|
||||||
|
|
||||||
|
my_crew = Crew(
|
||||||
|
agents=[...],
|
||||||
|
tasks=[...],
|
||||||
|
process=Process.sequential,
|
||||||
|
memory=True,
|
||||||
|
verbose=True,
|
||||||
|
embedder=OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Ollama embeddings
|
||||||
|
|
||||||
|
```python Code
|
||||||
|
from crewai import Crew, Agent, Task, Process
|
||||||
|
|
||||||
|
my_crew = Crew(
|
||||||
|
agents=[...],
|
||||||
|
tasks=[...],
|
||||||
|
process=Process.sequential,
|
||||||
|
memory=True,
|
||||||
|
verbose=True,
|
||||||
|
embedder={
|
||||||
|
"provider": "ollama",
|
||||||
|
"config": {
|
||||||
|
"model": "mxbai-embed-large"
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -122,10 +161,13 @@ my_crew = Crew(
|
|||||||
process=Process.sequential,
|
process=Process.sequential,
|
||||||
memory=True,
|
memory=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
embedder=embedding_functions.OpenAIEmbeddingFunction(
|
embedder={
|
||||||
api_key=os.getenv("OPENAI_API_KEY"),
|
"provider": "google",
|
||||||
model_name="text-embedding-ada-002"
|
"config": {
|
||||||
)
|
"api_key": "<YOUR_API_KEY>",
|
||||||
|
"model": "<model_name>"
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -181,10 +223,32 @@ my_crew = Crew(
|
|||||||
process=Process.sequential,
|
process=Process.sequential,
|
||||||
memory=True,
|
memory=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
embedder=embedding_functions.CohereEmbeddingFunction(
|
embedder={
|
||||||
api_key=YOUR_API_KEY,
|
"provider": "cohere",
|
||||||
model_name="<model_name>"
|
"config": {
|
||||||
|
"api_key": "YOUR_API_KEY",
|
||||||
|
"model": "<model_name>"
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
```
|
||||||
|
### Using HuggingFace embeddings
|
||||||
|
|
||||||
|
```python Code
|
||||||
|
from crewai import Crew, Agent, Task, Process
|
||||||
|
|
||||||
|
my_crew = Crew(
|
||||||
|
agents=[...],
|
||||||
|
tasks=[...],
|
||||||
|
process=Process.sequential,
|
||||||
|
memory=True,
|
||||||
|
verbose=True,
|
||||||
|
embedder={
|
||||||
|
"provider": "huggingface",
|
||||||
|
"config": {
|
||||||
|
"api_url": "<api_url>",
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class EntityMemory(Memory):
|
|||||||
if storage
|
if storage
|
||||||
else RAGStorage(
|
else RAGStorage(
|
||||||
type="entities",
|
type="entities",
|
||||||
allow_reset=False,
|
allow_reset=True,
|
||||||
embedder_config=embedder_config,
|
embedder_config=embedder_config,
|
||||||
crew=crew,
|
crew=crew,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ from typing import Any, Dict, List, Optional
|
|||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
|
from chromadb.api.types import validate_embedding_function
|
||||||
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@@ -41,16 +44,93 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
self.agents = agents
|
self.agents = agents
|
||||||
|
|
||||||
self.type = type
|
self.type = type
|
||||||
self.embedder_config = embedder_config or self._create_embedding_function()
|
|
||||||
self.allow_reset = allow_reset
|
self.allow_reset = allow_reset
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
|
def _set_embedder_config(self):
    """Resolve ``self.embedder_config`` into an embedding function object.

    Accepted values of ``self.embedder_config``:

    * ``None`` -- replaced by ``self._create_default_embedding_function()``.
    * ``dict`` -- must hold a ``"provider"`` key and may hold a ``"config"``
      dict. The model is read from ``config["model"]``, falling back to
      ``config["model_name"]``. For the ``"ollama"`` provider the endpoint
      is read from ``config["base_url"]`` and defaults to
      ``http://localhost:11434/v1``.
    * anything else -- passed to ``validate_embedding_function`` and kept
      unchanged.

    Raises:
        Exception: if the dict names a provider that is not supported.
    """
    import chromadb.utils.embedding_functions as embedding_functions

    if self.embedder_config is None:
        self.embedder_config = self._create_default_embedding_function()

    if not isinstance(self.embedder_config, dict):
        # An embedding function/class was supplied directly (or the default
        # was just created): validate it and keep it as is.
        validate_embedding_function(self.embedder_config)  # type: ignore
        return

    provider = self.embedder_config.get("provider")
    # ``or {}`` also covers an explicit ``"config": None``.
    config = self.embedder_config.get("config") or {}
    # Accept both spellings of the model key so that configs written as
    # ``"model_name"`` do not silently end up with ``model_name=None``.
    model_name = config.get("model") or config.get("model_name")

    if provider == "openai":
        self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
            api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
            model_name=model_name,
        )
    elif provider == "azure":
        self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
            api_key=config.get("api_key"),
            api_base=config.get("api_base"),
            api_type=config.get("api_type", "azure"),
            api_version=config.get("api_version"),
            model_name=model_name,
        )
    elif provider == "ollama":
        from openai import OpenAI

        # Build the client once instead of on every embedding call.
        ollama_client = OpenAI(
            base_url=config.get("base_url", "http://localhost:11434/v1"),
            api_key=config.get("api_key", "ollama"),
        )

        class OllamaEmbeddingFunction(EmbeddingFunction):
            def __call__(self, input: Documents) -> Embeddings:
                response = ollama_client.embeddings.create(
                    input=input, model=model_name
                )
                embeddings = [item.embedding for item in response.data]
                return cast(Embeddings, embeddings)

        self.embedder_config = OllamaEmbeddingFunction()
    elif provider == "vertexai":
        self.embedder_config = embedding_functions.GoogleVertexEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )
    elif provider == "google":
        self.embedder_config = (
            embedding_functions.GoogleGenerativeAiEmbeddingFunction(
                model_name=model_name,
                api_key=config.get("api_key"),
            )
        )
    elif provider == "cohere":
        self.embedder_config = embedding_functions.CohereEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )
    elif provider == "huggingface":
        self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
            url=config.get("api_url"),
        )
    else:
        raise Exception(
            f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
        )
|
||||||
|
|
||||||
def _initialize_app(self):
|
def _initialize_app(self):
|
||||||
import chromadb
|
import chromadb
|
||||||
|
from chromadb.config import Settings
|
||||||
|
|
||||||
|
self._set_embedder_config()
|
||||||
chroma_client = chromadb.PersistentClient(
|
chroma_client = chromadb.PersistentClient(
|
||||||
path=f"{db_storage_path()}/{self.type}/{self.agents}"
|
path=f"{db_storage_path()}/{self.type}/{self.agents}",
|
||||||
|
settings=Settings(allow_reset=self.allow_reset),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app = chroma_client
|
self.app = chroma_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -122,11 +202,15 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
if self.app:
|
if self.app:
|
||||||
self.app.reset()
|
self.app.reset()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if "attempt to write a readonly database" in str(e):
|
||||||
|
# Ignore this specific error
|
||||||
|
pass
|
||||||
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_embedding_function(self):
|
def _create_default_embedding_function(self):
|
||||||
import chromadb.utils.embedding_functions as embedding_functions
|
import chromadb.utils.embedding_functions as embedding_functions
|
||||||
|
|
||||||
return embedding_functions.OpenAIEmbeddingFunction(
|
return embedding_functions.OpenAIEmbeddingFunction(
|
||||||
|
|||||||
Reference in New Issue
Block a user