Merge branch 'main' into brandon/fix-google-docs

This commit is contained in:
Brandon Hancock (bhancock_ai)
2025-02-07 16:52:40 -05:00
committed by GitHub
4 changed files with 80 additions and 12 deletions

View File

@@ -368,6 +368,33 @@ my_crew = Crew(
) )
``` ```
### Adding Custom Embedding Function
```python Code
from crewai import Crew, Agent, Task, Process
from chromadb import Documents, EmbeddingFunction, Embeddings
# Create a custom embedding function
class CustomEmbedder(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
# generate embeddings
return [1, 2, 3] # this is a dummy embedding
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "custom",
"config": {
"embedder": CustomEmbedder()
}
}
)
```
### Resetting Memory ### Resetting Memory
```shell ```shell

View File

@@ -10,6 +10,8 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata. Base class for memory, now supporting agent tags and generic metadata.
""" """
embedder_config: Optional[Dict[str, Any]] = None
storage: Any storage: Any
def __init__(self, storage: Any, **data: Any): def __init__(self, storage: Any, **data: Any):

View File

@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
self, self,
type: str, type: str,
allow_reset: bool = True, allow_reset: bool = True,
embedder_config: Optional[Any] = None, embedder_config: Optional[Dict[str, Any]] = None,
crew: Any = None, crew: Any = None,
): ):
self.type = type self.type = type

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Dict, cast from typing import Any, Dict, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
@@ -18,11 +18,12 @@ class EmbeddingConfigurator:
"bedrock": self._configure_bedrock, "bedrock": self._configure_bedrock,
"huggingface": self._configure_huggingface, "huggingface": self._configure_huggingface,
"watson": self._configure_watson, "watson": self._configure_watson,
"custom": self._configure_custom,
} }
def configure_embedder( def configure_embedder(
self, self,
embedder_config: Dict[str, Any] | None = None, embedder_config: Optional[Dict[str, Any]] = None,
) -> EmbeddingFunction: ) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config.""" """Configures and returns an embedding function based on the provided config."""
if embedder_config is None: if embedder_config is None:
@@ -30,20 +31,19 @@ class EmbeddingConfigurator:
provider = embedder_config.get("provider") provider = embedder_config.get("provider")
config = embedder_config.get("config", {}) config = embedder_config.get("config", {})
model_name = config.get("model") model_name = config.get("model") if provider != "custom" else None
if isinstance(provider, EmbeddingFunction):
try:
validate_embedding_function(provider)
return provider
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
if provider not in self.embedding_functions: if provider not in self.embedding_functions:
raise Exception( raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
) )
return self.embedding_functions[provider](config, model_name)
embedding_function = self.embedding_functions[provider]
return (
embedding_function(config)
if provider == "custom"
else embedding_function(config, model_name)
)
@staticmethod @staticmethod
def _create_default_embedding_function(): def _create_default_embedding_function():
@@ -64,6 +64,13 @@ class EmbeddingConfigurator:
return OpenAIEmbeddingFunction( return OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name, model_name=model_name,
api_base=config.get("api_base", None),
api_type=config.get("api_type", None),
api_version=config.get("api_version", None),
default_headers=config.get("default_headers", None),
dimensions=config.get("dimensions", None),
deployment_id=config.get("deployment_id", None),
organization_id=config.get("organization_id", None),
) )
@staticmethod @staticmethod
@@ -78,6 +85,10 @@ class EmbeddingConfigurator:
api_type=config.get("api_type", "azure"), api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"), api_version=config.get("api_version"),
model_name=model_name, model_name=model_name,
default_headers=config.get("default_headers"),
dimensions=config.get("dimensions"),
deployment_id=config.get("deployment_id"),
organization_id=config.get("organization_id"),
) )
@staticmethod @staticmethod
@@ -100,6 +111,8 @@ class EmbeddingConfigurator:
return GoogleVertexEmbeddingFunction( return GoogleVertexEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.get("api_key"),
project_id=config.get("project_id"),
region=config.get("region"),
) )
@staticmethod @staticmethod
@@ -111,6 +124,7 @@ class EmbeddingConfigurator:
return GoogleGenerativeAiEmbeddingFunction( return GoogleGenerativeAiEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.get("api_key"),
task_type=config.get("task_type"),
) )
@staticmethod @staticmethod
@@ -195,3 +209,28 @@ class EmbeddingConfigurator:
raise e raise e
return WatsonEmbeddingFunction() return WatsonEmbeddingFunction()
@staticmethod
def _configure_custom(config):
custom_embedder = config.get("embedder")
if isinstance(custom_embedder, EmbeddingFunction):
try:
validate_embedding_function(custom_embedder)
return custom_embedder
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
elif callable(custom_embedder):
try:
instance = custom_embedder()
if isinstance(instance, EmbeddingFunction):
validate_embedding_function(instance)
return instance
raise ValueError(
"Custom embedder does not create an EmbeddingFunction instance"
)
except Exception as e:
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
else:
raise ValueError(
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
)