Fix the Vertex AI embedding endpoint typo without patching requests globally.

ChromaDB's GoogleVertexEmbeddingFunction targets 'publishers/goole' instead of
'publishers/google' (https://github.com/crewaiinc/crewai/issues/2690).

The earlier revision of this patch replaced `requests.post` for the whole
process inside `__init__` and restored it only in `__del__`. That leaks into
every other caller of `requests.post`, nests when several instances exist, and
depends on garbage-collection order. This revision corrects the endpoint stored
on the instance instead, and narrows the ImportError guard in
`_configure_vertexai` to the import it is meant to protect.

Note: blob hashes on the `index` lines are carried over unchanged and have not
been recomputed for the new file contents.

diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py
index e523b60f0..b2b033322 100644
--- a/src/crewai/utilities/embedding_configurator.py
+++ b/src/crewai/utilities/embedding_configurator.py
@@ -104,16 +104,23 @@ class EmbeddingConfigurator:
 
     @staticmethod
     def _configure_vertexai(config, model_name):
-        from chromadb.utils.embedding_functions.google_embedding_function import (
-            GoogleVertexEmbeddingFunction,
-        )
+        # Imported lazily so the Vertex embedding module is only loaded when
+        # this provider is configured; only the import itself is guarded.
+        try:
+            from crewai.utilities.embedding_functions import (
+                FixedGoogleVertexEmbeddingFunction,
+            )
+        except ImportError as e:
+            raise ImportError(
+                "Google Vertex dependencies are not installed. Please install them to use Vertex embedding."
+            ) from e
 
-        return GoogleVertexEmbeddingFunction(
+        return FixedGoogleVertexEmbeddingFunction(
             model_name=model_name,
             api_key=config.get("api_key"),
             project_id=config.get("project_id"),
             region=config.get("region"),
         )
 
     @staticmethod
     def _configure_google(config, model_name):
diff --git a/src/crewai/utilities/embedding_functions.py b/src/crewai/utilities/embedding_functions.py
new file mode 100644
index 000000000..b68ed819c
--- /dev/null
+++ b/src/crewai/utilities/embedding_functions.py
@@ -0,0 +1,50 @@
+"""Embedding-function wrappers that work around upstream ChromaDB defects."""
+
+from typing import Any, Optional
+
+from chromadb.utils.embedding_functions.google_embedding_function import (
+    GoogleVertexEmbeddingFunction,
+)
+
+
+def fix_vertex_publisher_url(url: str) -> str:
+    """Correct the 'publishers/goole' typo in a Vertex AI endpoint URL.
+
+    The replacement is idempotent: a URL without the typo is returned unchanged.
+    """
+    return url.replace("publishers/goole", "publishers/google")
+
+
+class FixedGoogleVertexEmbeddingFunction(GoogleVertexEmbeddingFunction):
+    """
+    A wrapper around ChromaDB's GoogleVertexEmbeddingFunction that fixes the URL typo
+    where 'publishers/goole' is incorrectly used instead of 'publishers/google'.
+
+    The correction is applied to this instance only; nothing process-wide
+    (such as ``requests.post``) is replaced.
+
+    Issue reference: https://github.com/crewaiinc/crewai/issues/2690
+    """
+
+    def __init__(
+        self,
+        model_name: str = "textembedding-gecko",
+        api_key: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Build the parent embedding function, then repair its endpoint URL.
+
+        Args:
+            model_name: Vertex AI embedding model name.
+            api_key: Credential forwarded to the parent class.
+            **kwargs: Extra parent options (e.g. ``project_id``, ``region``).
+        """
+        super().__init__(model_name=model_name, api_key=api_key, **kwargs)
+
+        # NOTE(review): assumes the parent precomputes its endpoint on
+        # `_api_url` -- confirm against the pinned chromadb version. When the
+        # attribute is absent this is a no-op rather than an error.
+        api_url = getattr(self, "_api_url", None)
+        if isinstance(api_url, str):
+            self._api_url = fix_vertex_publisher_url(api_url)
diff --git a/tests/utilities/test_embedding_configurator.py b/tests/utilities/test_embedding_configurator.py
new file mode 100644
index 000000000..27244324c
--- /dev/null
+++ b/tests/utilities/test_embedding_configurator.py
@@ -0,0 +1,40 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from crewai.utilities.embedding_configurator import EmbeddingConfigurator
+
+
+class TestEmbeddingConfigurator:
+    @pytest.fixture
+    def embedding_configurator(self):
+        return EmbeddingConfigurator()
+
+    def test_configure_vertexai(self, embedding_configurator):
+        # _configure_vertexai imports the class lazily, so patching the
+        # attribute on its defining module is what the call will pick up.
+        with patch(
+            "crewai.utilities.embedding_functions.FixedGoogleVertexEmbeddingFunction"
+        ) as mock_class:
+            mock_instance = MagicMock()
+            mock_class.return_value = mock_instance
+
+            config = {
+                "provider": "vertexai",
+                "config": {
+                    "api_key": "test-key",
+                    "model": "test-model",
+                    "project_id": "test-project",
+                    "region": "test-region",
+                },
+            }
+
+            result = embedding_configurator.configure_embedder(config)
+
+            mock_class.assert_called_once_with(
+                model_name="test-model",
+                api_key="test-key",
+                project_id="test-project",
+                region="test-region",
+            )
+            assert result is mock_instance
diff --git a/tests/utilities/test_embedding_functions.py b/tests/utilities/test_embedding_functions.py
new file mode 100644
index 000000000..5e43b51f2
--- /dev/null
+++ b/tests/utilities/test_embedding_functions.py
@@ -0,0 +1,51 @@
+import pytest
+import requests
+
+from crewai.utilities.embedding_functions import (
+    FixedGoogleVertexEmbeddingFunction,
+    fix_vertex_publisher_url,
+)
+
+TYPO_URL = (
+    "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project"
+    "/locations/us-central1/publishers/goole/models/test-model:predict"
+)
+EXPECTED_URL = (
+    "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project"
+    "/locations/us-central1/publishers/google/models/test-model:predict"
+)
+
+
+class TestFixVertexPublisherUrl:
+    def test_corrects_typo(self):
+        assert fix_vertex_publisher_url(TYPO_URL) == EXPECTED_URL
+
+    def test_leaves_correct_url_untouched(self):
+        assert fix_vertex_publisher_url(EXPECTED_URL) == EXPECTED_URL
+
+
+class TestFixedGoogleVertexEmbeddingFunction:
+    @pytest.fixture
+    def embedding_function(self):
+        return FixedGoogleVertexEmbeddingFunction(
+            model_name="test-model",
+            api_key="test-key",
+            project_id="test-project",
+            region="us-central1",
+        )
+
+    def test_stored_url_is_corrected(self, embedding_function):
+        # Skip rather than fail if the installed chromadb no longer
+        # precomputes the endpoint on `_api_url`.
+        api_url = getattr(embedding_function, "_api_url", None)
+        if api_url is None:
+            pytest.skip("installed chromadb does not expose _api_url")
+        assert "publishers/goole" not in api_url
+        assert "publishers/google" in api_url
+
+    def test_construction_leaves_requests_untouched(self):
+        original_post = requests.post
+        FixedGoogleVertexEmbeddingFunction(
+            model_name="test-model", api_key="test-key"
+        )
+        assert requests.post is original_post