Compare commits

...

4 Commits

Author SHA1 Message Date
Lorenze Jay
2d44356c81 Refine custom embedder configuration support
- Update custom embedder configuration method to handle custom embedding functions
- Modify type hints for embedder configuration
- Remove unused model_name parameter in custom embedder configuration
2025-02-07 13:41:36 -08:00
Lorenze Jay
d48211f7f8 added docs 2025-02-07 12:48:17 -08:00
Lorenze Jay
8eea0bd502 Merge branch 'main' of github.com:crewAIInc/crewAI into fix/embedder-config 2025-02-07 12:45:30 -08:00
Lorenze Jay
cafac13447 Enhance embedding configuration with custom embedder support
- Add support for custom embedding functions in EmbeddingConfigurator
- Update type hints for embedder configuration
- Extend configuration options for various embedding providers
- Add optional embedder configuration to Memory class
2025-02-07 12:41:57 -08:00
4 changed files with 80 additions and 12 deletions

View File

@@ -368,6 +368,33 @@ my_crew = Crew(
) )
``` ```
### Adding Custom Embedding Function
```python Code
from crewai import Crew, Agent, Task, Process
from chromadb import Documents, EmbeddingFunction, Embeddings
# Create a custom embedding function
class CustomEmbedder(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
# generate embeddings
return [1, 2, 3] # this is a dummy embedding
my_crew = Crew(
agents=[...],
tasks=[...],
process=Process.sequential,
memory=True,
verbose=True,
embedder={
"provider": "custom",
"config": {
"embedder": CustomEmbedder()
}
}
)
```
### Resetting Memory ### Resetting Memory
```shell ```shell

View File

@@ -10,6 +10,8 @@ class Memory(BaseModel):
Base class for memory, now supporting agent tags and generic metadata. Base class for memory, now supporting agent tags and generic metadata.
""" """
embedder_config: Optional[Dict[str, Any]] = None
storage: Any storage: Any
def __init__(self, storage: Any, **data: Any): def __init__(self, storage: Any, **data: Any):

View File

@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
self, self,
type: str, type: str,
allow_reset: bool = True, allow_reset: bool = True,
embedder_config: Optional[Any] = None, embedder_config: Optional[Dict[str, Any]] = None,
crew: Any = None, crew: Any = None,
): ):
self.type = type self.type = type

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Dict, cast from typing import Any, Dict, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
@@ -18,11 +18,12 @@ class EmbeddingConfigurator:
"bedrock": self._configure_bedrock, "bedrock": self._configure_bedrock,
"huggingface": self._configure_huggingface, "huggingface": self._configure_huggingface,
"watson": self._configure_watson, "watson": self._configure_watson,
"custom": self._configure_custom,
} }
def configure_embedder( def configure_embedder(
self, self,
embedder_config: Dict[str, Any] | None = None, embedder_config: Optional[Dict[str, Any]] = None,
) -> EmbeddingFunction: ) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config.""" """Configures and returns an embedding function based on the provided config."""
if embedder_config is None: if embedder_config is None:
@@ -30,20 +31,19 @@ class EmbeddingConfigurator:
provider = embedder_config.get("provider") provider = embedder_config.get("provider")
config = embedder_config.get("config", {}) config = embedder_config.get("config", {})
model_name = config.get("model") model_name = config.get("model") if provider != "custom" else None
if isinstance(provider, EmbeddingFunction):
try:
validate_embedding_function(provider)
return provider
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
if provider not in self.embedding_functions: if provider not in self.embedding_functions:
raise Exception( raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
) )
return self.embedding_functions[provider](config, model_name)
embedding_function = self.embedding_functions[provider]
return (
embedding_function(config)
if provider == "custom"
else embedding_function(config, model_name)
)
@staticmethod @staticmethod
def _create_default_embedding_function(): def _create_default_embedding_function():
@@ -64,6 +64,13 @@ class EmbeddingConfigurator:
return OpenAIEmbeddingFunction( return OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name, model_name=model_name,
api_base=config.get("api_base", None),
api_type=config.get("api_type", None),
api_version=config.get("api_version", None),
default_headers=config.get("default_headers", None),
dimensions=config.get("dimensions", None),
deployment_id=config.get("deployment_id", None),
organization_id=config.get("organization_id", None),
) )
@staticmethod @staticmethod
@@ -78,6 +85,10 @@ class EmbeddingConfigurator:
api_type=config.get("api_type", "azure"), api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"), api_version=config.get("api_version"),
model_name=model_name, model_name=model_name,
default_headers=config.get("default_headers"),
dimensions=config.get("dimensions"),
deployment_id=config.get("deployment_id"),
organization_id=config.get("organization_id"),
) )
@staticmethod @staticmethod
@@ -100,6 +111,8 @@ class EmbeddingConfigurator:
return GoogleVertexEmbeddingFunction( return GoogleVertexEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.get("api_key"),
project_id=config.get("project_id"),
region=config.get("region"),
) )
@staticmethod @staticmethod
@@ -111,6 +124,7 @@ class EmbeddingConfigurator:
return GoogleGenerativeAiEmbeddingFunction( return GoogleGenerativeAiEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.get("api_key"),
task_type=config.get("task_type"),
) )
@staticmethod @staticmethod
@@ -195,3 +209,28 @@ class EmbeddingConfigurator:
raise e raise e
return WatsonEmbeddingFunction() return WatsonEmbeddingFunction()
@staticmethod
def _configure_custom(config):
custom_embedder = config.get("embedder")
if isinstance(custom_embedder, EmbeddingFunction):
try:
validate_embedding_function(custom_embedder)
return custom_embedder
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
elif callable(custom_embedder):
try:
instance = custom_embedder()
if isinstance(instance, EmbeddingFunction):
validate_embedding_function(instance)
return instance
raise ValueError(
"Custom embedder does not create an EmbeddingFunction instance"
)
except Exception as e:
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
else:
raise ValueError(
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
)