mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-23 15:58:30 +00:00
Compare commits
4 Commits
docs/train
...
fix/embedd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d44356c81 | ||
|
|
d48211f7f8 | ||
|
|
8eea0bd502 | ||
|
|
cafac13447 |
@@ -368,6 +368,33 @@ my_crew = Crew(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Adding Custom Embedding Function
|
||||||
|
|
||||||
|
```python Code
|
||||||
|
from crewai import Crew, Agent, Task, Process
|
||||||
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
|
|
||||||
|
# Create a custom embedding function
|
||||||
|
class CustomEmbedder(EmbeddingFunction):
|
||||||
|
def __call__(self, input: Documents) -> Embeddings:
|
||||||
|
# generate embeddings
|
||||||
|
return [1, 2, 3] # this is a dummy embedding
|
||||||
|
|
||||||
|
my_crew = Crew(
|
||||||
|
agents=[...],
|
||||||
|
tasks=[...],
|
||||||
|
process=Process.sequential,
|
||||||
|
memory=True,
|
||||||
|
verbose=True,
|
||||||
|
embedder={
|
||||||
|
"provider": "custom",
|
||||||
|
"config": {
|
||||||
|
"embedder": CustomEmbedder()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
### Resetting Memory
|
### Resetting Memory
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ class Memory(BaseModel):
|
|||||||
Base class for memory, now supporting agent tags and generic metadata.
|
Base class for memory, now supporting agent tags and generic metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
embedder_config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
storage: Any
|
storage: Any
|
||||||
|
|
||||||
def __init__(self, storage: Any, **data: Any):
|
def __init__(self, storage: Any, **data: Any):
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
|
|||||||
self,
|
self,
|
||||||
type: str,
|
type: str,
|
||||||
allow_reset: bool = True,
|
allow_reset: bool = True,
|
||||||
embedder_config: Optional[Any] = None,
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
crew: Any = None,
|
crew: Any = None,
|
||||||
):
|
):
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, cast
|
from typing import Any, Dict, Optional, cast
|
||||||
|
|
||||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
from chromadb.api.types import validate_embedding_function
|
from chromadb.api.types import validate_embedding_function
|
||||||
@@ -18,11 +18,12 @@ class EmbeddingConfigurator:
|
|||||||
"bedrock": self._configure_bedrock,
|
"bedrock": self._configure_bedrock,
|
||||||
"huggingface": self._configure_huggingface,
|
"huggingface": self._configure_huggingface,
|
||||||
"watson": self._configure_watson,
|
"watson": self._configure_watson,
|
||||||
|
"custom": self._configure_custom,
|
||||||
}
|
}
|
||||||
|
|
||||||
def configure_embedder(
|
def configure_embedder(
|
||||||
self,
|
self,
|
||||||
embedder_config: Dict[str, Any] | None = None,
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
) -> EmbeddingFunction:
|
) -> EmbeddingFunction:
|
||||||
"""Configures and returns an embedding function based on the provided config."""
|
"""Configures and returns an embedding function based on the provided config."""
|
||||||
if embedder_config is None:
|
if embedder_config is None:
|
||||||
@@ -30,20 +31,19 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
provider = embedder_config.get("provider")
|
provider = embedder_config.get("provider")
|
||||||
config = embedder_config.get("config", {})
|
config = embedder_config.get("config", {})
|
||||||
model_name = config.get("model")
|
model_name = config.get("model") if provider != "custom" else None
|
||||||
|
|
||||||
if isinstance(provider, EmbeddingFunction):
|
|
||||||
try:
|
|
||||||
validate_embedding_function(provider)
|
|
||||||
return provider
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
|
||||||
|
|
||||||
if provider not in self.embedding_functions:
|
if provider not in self.embedding_functions:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||||
)
|
)
|
||||||
return self.embedding_functions[provider](config, model_name)
|
|
||||||
|
embedding_function = self.embedding_functions[provider]
|
||||||
|
return (
|
||||||
|
embedding_function(config)
|
||||||
|
if provider == "custom"
|
||||||
|
else embedding_function(config, model_name)
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_default_embedding_function():
|
def _create_default_embedding_function():
|
||||||
@@ -64,6 +64,13 @@ class EmbeddingConfigurator:
|
|||||||
return OpenAIEmbeddingFunction(
|
return OpenAIEmbeddingFunction(
|
||||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
api_base=config.get("api_base", None),
|
||||||
|
api_type=config.get("api_type", None),
|
||||||
|
api_version=config.get("api_version", None),
|
||||||
|
default_headers=config.get("default_headers", None),
|
||||||
|
dimensions=config.get("dimensions", None),
|
||||||
|
deployment_id=config.get("deployment_id", None),
|
||||||
|
organization_id=config.get("organization_id", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -78,6 +85,10 @@ class EmbeddingConfigurator:
|
|||||||
api_type=config.get("api_type", "azure"),
|
api_type=config.get("api_type", "azure"),
|
||||||
api_version=config.get("api_version"),
|
api_version=config.get("api_version"),
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
default_headers=config.get("default_headers"),
|
||||||
|
dimensions=config.get("dimensions"),
|
||||||
|
deployment_id=config.get("deployment_id"),
|
||||||
|
organization_id=config.get("organization_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -100,6 +111,8 @@ class EmbeddingConfigurator:
|
|||||||
return GoogleVertexEmbeddingFunction(
|
return GoogleVertexEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.get("api_key"),
|
||||||
|
project_id=config.get("project_id"),
|
||||||
|
region=config.get("region"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -111,6 +124,7 @@ class EmbeddingConfigurator:
|
|||||||
return GoogleGenerativeAiEmbeddingFunction(
|
return GoogleGenerativeAiEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.get("api_key"),
|
||||||
|
task_type=config.get("task_type"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -195,3 +209,28 @@ class EmbeddingConfigurator:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
return WatsonEmbeddingFunction()
|
return WatsonEmbeddingFunction()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_custom(config):
|
||||||
|
custom_embedder = config.get("embedder")
|
||||||
|
if isinstance(custom_embedder, EmbeddingFunction):
|
||||||
|
try:
|
||||||
|
validate_embedding_function(custom_embedder)
|
||||||
|
return custom_embedder
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||||
|
elif callable(custom_embedder):
|
||||||
|
try:
|
||||||
|
instance = custom_embedder()
|
||||||
|
if isinstance(instance, EmbeddingFunction):
|
||||||
|
validate_embedding_function(instance)
|
||||||
|
return instance
|
||||||
|
raise ValueError(
|
||||||
|
"Custom embedder does not create an EmbeddingFunction instance"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user