mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-04 13:48:31 +00:00
Compare commits
18 Commits
brandon/cr
...
fix/memory
| Author | SHA1 | Date | |
|---|---|---|---|
| | 1589230833 | | |
| | 6d0251224e | | |
| | e2f70cb53f | | |
| | 8731915330 | | |
| | 0dd522ddff | | |
| | 093259389e | | |
| | 5803b3fb69 | | |
| | 31c3082740 | | |
| | 21afc46c0d | | |
| | 78882c6de2 | | |
| | 2786086974 | | |
| | 6bcb3d1080 | | |
| | 6b12ac9c0b | | |
| | 266ecff395 | | |
| | 71a217b210 | | |
| | 34d748d18e | | |
| | 79f527576b | | |
| | 3fc83c624b | | |
@@ -62,6 +62,8 @@ os.environ["OPENAI_API_BASE"] = "https://api.your-provider.com/v1"
|
||||
2. Using LLM class attributes:
|
||||
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="custom-model-name",
|
||||
api_key="your-api-key",
|
||||
@@ -95,9 +97,11 @@ When configuring an LLM for your agent, you have access to a wide range of param
|
||||
| **api_key** | `str` | Your API key for authentication. |
|
||||
|
||||
|
||||
Example:
|
||||
## OpenAI Example Configuration
|
||||
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="gpt-4",
|
||||
temperature=0.8,
|
||||
@@ -112,15 +116,31 @@ llm = LLM(
|
||||
)
|
||||
agent = Agent(llm=llm, ...)
|
||||
```
|
||||
|
||||
## Cerebras Example Configuration
|
||||
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="cerebras/llama-3.1-70b",
|
||||
base_url="https://api.cerebras.ai/v1",
|
||||
api_key="your-api-key-here"
|
||||
)
|
||||
agent = Agent(llm=llm, ...)
|
||||
```
|
||||
|
||||
## Using Ollama (Local LLMs)
|
||||
|
||||
crewAI supports using Ollama for running open-source models locally:
|
||||
CrewAI supports using Ollama for running open-source models locally:
|
||||
|
||||
1. Install Ollama: [ollama.ai](https://ollama.ai/)
|
||||
2. Run a model: `ollama run llama3.1`
|
||||
3. Configure agent:
|
||||
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
|
||||
agent = Agent(
|
||||
llm=LLM(model="ollama/llama3.1", base_url="http://localhost:11434"),
|
||||
...
|
||||
@@ -132,6 +152,8 @@ agent = Agent(
|
||||
You can change the base API URL for any LLM provider by setting the `base_url` parameter:
|
||||
|
||||
```python Code
|
||||
from crewai import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="custom-model-name",
|
||||
base_url="https://api.your-provider.com/v1",
|
||||
|
||||
@@ -105,9 +105,48 @@ my_crew = Crew(
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
embedder={
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": 'text-embedding-3-small'
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
Alternatively, you can directly pass the OpenAIEmbeddingFunction to the embedder parameter.
|
||||
|
||||
Example:
|
||||
```python Code
|
||||
from crewai import Crew, Agent, Task, Process
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
|
||||
|
||||
my_crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=OpenAIEmbeddingFunction(api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"),
|
||||
)
|
||||
```
|
||||
|
||||
### Using Ollama embeddings
|
||||
|
||||
```python Code
|
||||
from crewai import Crew, Agent, Task, Process
|
||||
|
||||
my_crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder={
|
||||
"provider": "ollama",
|
||||
"config": {
|
||||
"model": "mxbai-embed-large"
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -122,10 +161,13 @@ my_crew = Crew(
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model_name="text-embedding-ada-002"
|
||||
)
|
||||
embedder={
|
||||
"provider": "google",
|
||||
"config": {
|
||||
"api_key": "<YOUR_API_KEY>",
|
||||
"model": "<model_name>"
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
@@ -181,10 +223,32 @@ my_crew = Crew(
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder=embedding_functions.CohereEmbeddingFunction(
|
||||
api_key=YOUR_API_KEY,
|
||||
model_name="<model_name>"
|
||||
)
|
||||
embedder={
|
||||
"provider": "cohere",
|
||||
"config": {
|
||||
"api_key": "YOUR_API_KEY",
|
||||
"model": "<model_name>"
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
### Using HuggingFace embeddings
|
||||
|
||||
```python Code
|
||||
from crewai import Crew, Agent, Task, Process
|
||||
|
||||
my_crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
verbose=True,
|
||||
embedder={
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"api_url": "<api_url>",
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import sys
|
||||
from {{folder_name}}.crew import {{crew_name}}Crew
|
||||
|
||||
# This main file is intended to be a way for you to run your
|
||||
# crew locally, so refrain from adding necessary logic into this file.
|
||||
# crew locally, so refrain from adding unnecessary logic into this file.
|
||||
# Replace with inputs you want to test with; it will automatically
|
||||
# interpolate any tasks and agents information
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class EntityMemory(Memory):
|
||||
if storage
|
||||
else RAGStorage(
|
||||
type="entities",
|
||||
allow_reset=False,
|
||||
allow_reset=True,
|
||||
embedder_config=embedder_config,
|
||||
crew=crew,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,9 @@ from typing import Any, Dict, List, Optional
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from typing import cast
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -41,16 +44,93 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.agents = agents
|
||||
|
||||
self.type = type
|
||||
self.embedder_config = embedder_config or self._create_embedding_function()
|
||||
|
||||
self.allow_reset = allow_reset
|
||||
self._initialize_app()
|
||||
|
||||
def _set_embedder_config(self):
    """Turn ``self.embedder_config`` into a ready-to-use embedding function.

    ``self.embedder_config`` may hold one of three things on entry:

    * ``None`` - it is replaced by ``self._create_default_embedding_function()``;
    * a dict of the form ``{"provider": <name>, "config": {...}}`` - it is
      replaced by the embedding function built for that provider;
    * anything else - it is treated as a user-supplied embedding
      function/class and only checked with ``validate_embedding_function``.

    Raises:
        Exception: if the dict names a provider that is not supported.
    """
    import chromadb.utils.embedding_functions as embedding_functions

    if self.embedder_config is None:
        self.embedder_config = self._create_default_embedding_function()

    if not isinstance(self.embedder_config, dict):
        # Not a provider dict: validate the embedding function and keep it.
        validate_embedding_function(self.embedder_config)  # type: ignore
        return

    provider = self.embedder_config.get("provider")
    # ``or {}`` also covers an explicit ``"config": None``.
    config = self.embedder_config.get("config") or {}
    # Accept "model_name" as an alias of "model" so that a config written
    # with either key resolves to a model instead of silently passing None.
    model_name = config.get("model") or config.get("model_name")

    if provider == "openai":
        self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
            # Fall back to the environment when no key is given in the config.
            api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
            model_name=model_name,
        )
    elif provider == "azure":
        self.embedder_config = embedding_functions.OpenAIEmbeddingFunction(
            api_key=config.get("api_key"),
            api_base=config.get("api_base"),
            api_type=config.get("api_type", "azure"),
            api_version=config.get("api_version"),
            model_name=model_name,
        )
    elif provider == "ollama":
        from openai import OpenAI

        # Embeddings are requested through the OpenAI client pointed at the
        # local Ollama endpoint. The client is built once here and reused by
        # every call instead of being re-created per embedding request.
        client = OpenAI(
            base_url="http://localhost:11434/v1",
            api_key=config.get("api_key", "ollama"),
        )

        class OllamaEmbeddingFunction(EmbeddingFunction):
            # The parameter keeps the name ``input`` used by the
            # EmbeddingFunction call signature, despite shadowing the builtin.
            def __call__(self, input: Documents) -> Embeddings:
                response = client.embeddings.create(input=input, model=model_name)
                return cast(Embeddings, [item.embedding for item in response.data])

        self.embedder_config = OllamaEmbeddingFunction()
    elif provider == "vertexai":
        self.embedder_config = embedding_functions.GoogleVertexEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )
    elif provider == "google":
        self.embedder_config = (
            embedding_functions.GoogleGenerativeAiEmbeddingFunction(
                model_name=model_name,
                api_key=config.get("api_key"),
            )
        )
    elif provider == "cohere":
        self.embedder_config = embedding_functions.CohereEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )
    elif provider == "huggingface":
        self.embedder_config = embedding_functions.HuggingFaceEmbeddingServer(
            url=config.get("api_url"),
        )
    else:
        raise Exception(
            f"Unsupported embedding provider: {provider}, supported providers: [openai, azure, ollama, vertexai, google, cohere, huggingface]"
        )
|
||||
|
||||
def _initialize_app(self):
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
self._set_embedder_config()
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=f"{db_storage_path()}/{self.type}/{self.agents}"
|
||||
path=f"{db_storage_path()}/{self.type}/{self.agents}",
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
try:
|
||||
@@ -122,11 +202,15 @@ class RAGStorage(BaseRAGStorage):
|
||||
if self.app:
|
||||
self.app.reset()
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
if "attempt to write a readonly database" in str(e):
|
||||
# Ignore this specific error
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||
)
|
||||
|
||||
def _create_embedding_function(self):
|
||||
def _create_default_embedding_function(self):
|
||||
import chromadb.utils.embedding_functions as embedding_functions
|
||||
|
||||
return embedding_functions.OpenAIEmbeddingFunction(
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import json
|
||||
import random
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from crewai_tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai import Agent, Task
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
|
||||
|
||||
@@ -44,36 +44,12 @@ example_task = Task(
|
||||
)
|
||||
|
||||
|
||||
def test_random_number_tool_usage():
|
||||
crew = Crew(
|
||||
agents=[example_agent],
|
||||
tasks=[example_task],
|
||||
)
|
||||
|
||||
with patch.object(random, "randint", return_value=42):
|
||||
result = crew.kickoff()
|
||||
|
||||
assert "42" in result.raw
|
||||
|
||||
|
||||
def test_random_number_tool_range():
    """Check that ``RandomNumberTool._run(1, 10)`` returns a value in [1, 10]."""
    generated = RandomNumberTool()._run(1, 10)
    assert 1 <= generated <= 10
|
||||
|
||||
|
||||
def test_random_number_tool_with_crew():
|
||||
crew = Crew(
|
||||
agents=[example_agent],
|
||||
tasks=[example_task],
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
# Check if the result contains a number between 1 and 100
|
||||
assert any(str(num) in result.raw for num in range(1, 101))
|
||||
|
||||
|
||||
def test_random_number_tool_invalid_range():
|
||||
tool = RandomNumberTool()
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
Reference in New Issue
Block a user