diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index b2b033322..31b31a836 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -108,6 +108,7 @@ class EmbeddingConfigurator: from chromadb.utils.embedding_functions.google_embedding_function import ( GoogleVertexEmbeddingFunction, ) + from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction return FixedGoogleVertexEmbeddingFunction( diff --git a/src/crewai/utilities/embedding_functions.py b/src/crewai/utilities/embedding_functions.py index b68ed819c..a0b43629f 100644 --- a/src/crewai/utilities/embedding_functions.py +++ b/src/crewai/utilities/embedding_functions.py @@ -1,10 +1,11 @@ -from typing import List, Any +from typing import Any, List, Optional + +import requests from chromadb import Documents, Embeddings from chromadb.utils.embedding_functions.google_embedding_function import ( GoogleVertexEmbeddingFunction, ) -import requests -from urllib.parse import urlparse, parse_qs, urlencode, urlunparse +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse class FixedGoogleVertexEmbeddingFunction(GoogleVertexEmbeddingFunction): @@ -17,7 +18,7 @@ class FixedGoogleVertexEmbeddingFunction(GoogleVertexEmbeddingFunction): def __init__(self, model_name: str = "textembedding-gecko", - api_key: str = None, + api_key: Optional[str] = None, **kwargs: Any): super().__init__(model_name=model_name, api_key=api_key, **kwargs) diff --git a/tests/utilities/test_embedding_configurator.py b/tests/utilities/test_embedding_configurator.py index 27244324c..4a25ba2eb 100644 --- a/tests/utilities/test_embedding_configurator.py +++ b/tests/utilities/test_embedding_configurator.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch from crewai.utilities.embedding_configurator import EmbeddingConfigurator from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction diff --git a/tests/utilities/test_embedding_functions.py b/tests/utilities/test_embedding_functions.py index 5e43b51f2..89039ac31 100644 --- a/tests/utilities/test_embedding_functions.py +++ b/tests/utilities/test_embedding_functions.py @@ -1,6 +1,6 @@ import pytest import requests -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction