Compare commits

...

3 Commits

Author SHA1 Message Date
Brandon Hancock (bhancock_ai)
8ae073c681 Merge branch 'main' into brandon/fix-google-docs 2025-02-07 16:52:40 -05:00
Lorenze Jay
e529766391 Enhance embedding configuration with custom embedder support (#2060)
* Enhance embedding configuration with custom embedder support

- Add support for custom embedding functions in EmbeddingConfigurator
- Update type hints for embedder configuration
- Extend configuration options for various embedding providers
- Add optional embedder configuration to Memory class

* added docs

* Refine custom embedder configuration support

- Update custom embedder configuration method to handle custom embedding functions
- Modify type hints for embedder configuration
- Remove unused model_name parameter in custom embedder configuration
2025-02-07 16:49:46 -05:00
Brandon Hancock
9a34d9c01e clean up google docs 2025-02-07 16:47:05 -05:00
5 changed files with 90 additions and 16 deletions

View File

@@ -463,26 +463,32 @@ Learn how to get the most out of your LLM configuration:
<Accordion title="Google">
```python Code
# Option 1. Gemini accessed with an API key.
# Option 1: Gemini accessed with an API key.
# https://ai.google.dev/gemini-api/docs/api-key
GEMINI_API_KEY=<your-api-key>
# Option 2. Vertex AI IAM credentials for Gemini, Anthropic, and anything in the Model Garden.
# Option 2: Vertex AI IAM credentials for Gemini, Anthropic, and Model Garden.
# https://cloud.google.com/vertex-ai/generative-ai/docs/overview
```
## GET CREDENTIALS
Get credentials:
```python Code
import json
file_path = 'path/to/vertex_ai_service_account.json'
# Load the JSON file
with open(file_path, 'r') as file:
vertex_credentials = json.load(file)
# Convert to JSON string
# Convert the credentials to a JSON string
vertex_credentials_json = json.dumps(vertex_credentials)
```
Example usage:
```python Code
from crewai import LLM
llm = LLM(
model="gemini/gemini-1.5-pro-latest",
temperature=0.7,

View File

@@ -368,6 +368,33 @@ my_crew = Crew(
)
```
### Adding Custom Embedding Function
```python Code
from crewai import Crew, Agent, Task, Process
from chromadb import Documents, EmbeddingFunction, Embeddings
# Create a custom embedding function
class CustomEmbedder(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
# generate embeddings
return [1, 2, 3] # this is a dummy embedding
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "custom",
"config": {
"embedder": CustomEmbedder()
}
}
)
```
### Resetting Memory
```shell

View File

@@ -10,6 +10,8 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata.
"""
embedder_config: Optional[Dict[str, Any]] = None
storage: Any
def __init__(self, storage: Any, **data: Any):

View File

@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: Optional[Any] = None,
embedder_config: Optional[Dict[str, Any]] = None,
crew: Any = None,
):
self.type = type

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, cast
from typing import Any, Dict, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function
@@ -18,11 +18,12 @@ class EmbeddingConfigurator:
"bedrock": self._configure_bedrock,
"huggingface": self._configure_huggingface,
"watson": self._configure_watson,
"custom": self._configure_custom,
}
def configure_embedder(
self,
embedder_config: Dict[str, Any] | None = None,
embedder_config: Optional[Dict[str, Any]] = None,
) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config."""
if embedder_config is None:
@@ -30,20 +31,19 @@ class EmbeddingConfigurator:
provider = embedder_config.get("provider")
config = embedder_config.get("config", {})
model_name = config.get("model")
if isinstance(provider, EmbeddingFunction):
try:
validate_embedding_function(provider)
return provider
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
model_name = config.get("model") if provider != "custom" else None
if provider not in self.embedding_functions:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
)
return self.embedding_functions[provider](config, model_name)
embedding_function = self.embedding_functions[provider]
return (
embedding_function(config)
if provider == "custom"
else embedding_function(config, model_name)
)
@staticmethod
def _create_default_embedding_function():
@@ -64,6 +64,13 @@ class EmbeddingConfigurator:
return OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
api_base=config.get("api_base", None),
api_type=config.get("api_type", None),
api_version=config.get("api_version", None),
default_headers=config.get("default_headers", None),
dimensions=config.get("dimensions", None),
deployment_id=config.get("deployment_id", None),
organization_id=config.get("organization_id", None),
)
@staticmethod
@@ -78,6 +85,10 @@ class EmbeddingConfigurator:
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
default_headers=config.get("default_headers"),
dimensions=config.get("dimensions"),
deployment_id=config.get("deployment_id"),
organization_id=config.get("organization_id"),
)
@staticmethod
@@ -100,6 +111,8 @@ class EmbeddingConfigurator:
return GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
project_id=config.get("project_id"),
region=config.get("region"),
)
@staticmethod
@@ -111,6 +124,7 @@ class EmbeddingConfigurator:
return GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
task_type=config.get("task_type"),
)
@staticmethod
@@ -195,3 +209,28 @@ class EmbeddingConfigurator:
raise e
return WatsonEmbeddingFunction()
@staticmethod
def _configure_custom(config):
custom_embedder = config.get("embedder")
if isinstance(custom_embedder, EmbeddingFunction):
try:
validate_embedding_function(custom_embedder)
return custom_embedder
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
elif callable(custom_embedder):
try:
instance = custom_embedder()
if isinstance(instance, EmbeddingFunction):
validate_embedding_function(instance)
return instance
raise ValueError(
"Custom embedder does not create an EmbeddingFunction instance"
)
except Exception as e:
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
else:
raise ValueError(
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
)