mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
Fix Vertex AI embeddings URL typo (publishers/goole -> publishers/google)
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -104,16 +104,22 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_vertexai(config, model_name):
|
def _configure_vertexai(config, model_name):
|
||||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
try:
|
||||||
GoogleVertexEmbeddingFunction,
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
)
|
GoogleVertexEmbeddingFunction,
|
||||||
|
)
|
||||||
return GoogleVertexEmbeddingFunction(
|
from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction
|
||||||
model_name=model_name,
|
|
||||||
api_key=config.get("api_key"),
|
return FixedGoogleVertexEmbeddingFunction(
|
||||||
project_id=config.get("project_id"),
|
model_name=model_name,
|
||||||
region=config.get("region"),
|
api_key=config.get("api_key"),
|
||||||
)
|
project_id=config.get("project_id"),
|
||||||
|
region=config.get("region"),
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Google Vertex dependencies are not installed. Please install them to use Vertex embedding."
|
||||||
|
) from e
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_google(config, model_name):
|
def _configure_google(config, model_name):
|
||||||
|
|||||||
38
src/crewai/utilities/embedding_functions.py
Normal file
38
src/crewai/utilities/embedding_functions.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from typing import List, Any
|
||||||
|
from chromadb import Documents, Embeddings
|
||||||
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
|
GoogleVertexEmbeddingFunction,
|
||||||
|
)
|
||||||
|
import requests
|
||||||
|
from urllib.parse import urlparse, parse_qs, urlencode, urlunparse
|
||||||
|
|
||||||
|
|
||||||
|
class FixedGoogleVertexEmbeddingFunction(GoogleVertexEmbeddingFunction):
    """
    A wrapper around ChromaDB's GoogleVertexEmbeddingFunction that fixes the URL typo
    where 'publishers/goole' is incorrectly used instead of 'publishers/google'.

    The correction is applied in two places:

    * the endpoint stored on the instance (``_api_url``), when the parent class
      exposes one, is rewritten once at construction time;
    * ``requests.post`` is wrapped only for the duration of ``__call__`` so that
      any request issued through it during an embedding call is rewritten too.

    The wrapper around ``requests.post`` is installed and removed inside
    ``__call__`` rather than in ``__init__``/``__del__``. A bound method stored
    in ``requests.post`` keeps the instance alive, so a ``__del__``-based
    restore would not run while the patch is installed, and restoring from
    ``__del__`` could undo a patch installed by a different instance.

    NOTE: swapping ``requests.post`` is not thread-safe; concurrent embedding
    calls from several threads may briefly observe the unpatched function.

    Issue reference: https://github.com/crewaiinc/crewai/issues/2690
    """

    # URL path segment containing the upstream typo, and its corrected form.
    _TYPO_SEGMENT = "publishers/goole"
    _FIXED_SEGMENT = "publishers/google"

    def __init__(self,
                 model_name: str = "textembedding-gecko",
                 api_key: str = None,
                 **kwargs: Any):
        """
        Build the embedding function and correct its endpoint URL.

        Args:
            model_name: Name of the Vertex AI embedding model.
            api_key: Credential forwarded to the parent class.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                parent constructor (e.g. ``project_id``, ``region``).
        """
        super().__init__(model_name=model_name, api_key=api_key, **kwargs)

        # Delegate used by `_patched_post`; refreshed on every `__call__`.
        # It is also set here so `_patched_post` works when invoked directly.
        self._original_post = requests.post

        # NOTE(review): some chromadb releases appear to keep the endpoint in
        # `_api_url` and post through a session instead of `requests.post` --
        # confirm against the installed version. Fixing the attribute directly
        # covers that case; it is a no-op when the attribute is absent.
        api_url = getattr(self, "_api_url", None)
        if isinstance(api_url, str):
            self._api_url = self._fix_url(api_url)

    @classmethod
    def _fix_url(cls, url: str) -> str:
        """Return *url* with the misspelled publisher segment corrected."""
        return url.replace(cls._TYPO_SEGMENT, cls._FIXED_SEGMENT)

    def _patched_post(self, url, *args, **kwargs):
        """
        Forward a POST request, correcting the publisher typo in *url* first.

        All other positional and keyword arguments are passed through
        unchanged to the wrapped ``post`` callable, whose result is returned.
        """
        # Only string URLs can contain the typo; leave anything else untouched.
        if isinstance(url, str):
            url = self._fix_url(url)

        return self._original_post(url, *args, **kwargs)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Embed *input* using the parent implementation with the URL fix active.

        ``requests.post`` is replaced for the duration of the call and always
        restored afterwards, including when the parent raises.
        """
        current_post = requests.post

        # If a wrapper from this class is already installed (re-entrant or
        # overlapping call), do not wrap again: that would make
        # `_original_post` point at a wrapper and recurse.
        already_wrapped = (
            getattr(current_post, "__func__", None)
            is FixedGoogleVertexEmbeddingFunction._patched_post
        )
        if already_wrapped:
            return super().__call__(input)

        self._original_post = current_post
        requests.post = self._patched_post
        try:
            return super().__call__(input)
        finally:
            requests.post = current_post
|
||||||
36
tests/utilities/test_embedding_configurator.py
Normal file
36
tests/utilities/test_embedding_configurator.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||||
|
from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingConfigurator:
    """Tests for the Vertex AI branch of ``EmbeddingConfigurator``."""

    @pytest.fixture
    def embedding_configurator(self):
        """Provide a new ``EmbeddingConfigurator`` for each test."""
        return EmbeddingConfigurator()

    def test_configure_vertexai(self, embedding_configurator):
        """The 'vertexai' provider builds the fixed wrapper from the config values."""
        provider_settings = {
            "api_key": "test-key",
            "model": "test-model",
            "project_id": "test-project",
            "region": "test-region"
        }
        config = {"provider": "vertexai", "config": provider_settings}

        # Patch the class on its defining module so the configurator picks up
        # the mock when it resolves the name.
        patch_target = 'crewai.utilities.embedding_functions.FixedGoogleVertexEmbeddingFunction'

        with patch(patch_target) as mock_class:
            mock_instance = MagicMock()
            mock_class.return_value = mock_instance

            result = embedding_configurator.configure_embedder(config)

            expected_kwargs = dict(
                model_name="test-model",
                api_key="test-key",
                project_id="test-project",
                region="test-region",
            )
            mock_class.assert_called_once_with(**expected_kwargs)
            assert result == mock_instance
|
||||||
51
tests/utilities/test_embedding_functions.py
Normal file
51
tests/utilities/test_embedding_functions.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction
|
||||||
|
|
||||||
|
|
||||||
|
class TestFixedGoogleVertexEmbeddingFunction:
    """Tests for the URL-correcting Vertex AI embedding wrapper."""

    @pytest.fixture
    def embedding_function(self):
        """Yield ``(function, mock_post)`` with ``requests.post`` mocked out."""
        with patch('requests.post') as mock_post:
            canned_response = MagicMock()
            canned_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
            mock_post.return_value = canned_response

            function = FixedGoogleVertexEmbeddingFunction(
                model_name="test-model",
                api_key="test-key"
            )

            yield function, mock_post

            # Put back whatever the wrapper captured before the surrounding
            # patch context exits.
            if hasattr(function, '_original_post'):
                requests.post = function._original_post

    def test_url_correction(self, embedding_function):
        """``_patched_post`` forwards the request with 'goole' rewritten to 'google'."""
        function, _mock_post = embedding_function

        typo_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/goole/models/test-model:predict"
        expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/test-model:predict"

        with patch.object(function, '_original_post') as mock_original_post:
            canned_response = MagicMock()
            canned_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
            mock_original_post.return_value = canned_response

            function._patched_post(typo_url, json={})

            mock_original_post.assert_called_once()
            positional_args, _keyword_args = mock_original_post.call_args
            assert positional_args[0] == expected_url

    def test_embedding_call(self, embedding_function):
        """Calling the function issues exactly one POST and returns a list."""
        function, mock_post = embedding_function

        embeddings = function(["test text"])

        mock_post.assert_called_once()
        assert isinstance(embeddings, list)
|
||||||
Reference in New Issue
Block a user