Compare commits

...

4 Commits

Author SHA1 Message Date
Devin AI
5623e2c851 Fix type error and test issues
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-25 21:03:36 +00:00
Devin AI
50059c7120 Fix import sorting with ruff
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-25 20:58:24 +00:00
Devin AI
335f1dfdf8 Fix lint and type errors
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-25 20:56:22 +00:00
Devin AI
1f2def2cbe Fix Vertex AI embeddings URL typo (publishers/goole -> publishers/google)
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-25 20:51:36 +00:00
4 changed files with 152 additions and 9 deletions

View File

@@ -104,16 +104,25 @@ class EmbeddingConfigurator:
@staticmethod
def _configure_vertexai(config, model_name):
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
try:
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
return GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
project_id=config.get("project_id"),
region=config.get("region"),
)
from crewai.utilities.embedding_functions import (
FixedGoogleVertexEmbeddingFunction,
)
return FixedGoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
project_id=config.get("project_id"),
region=config.get("region"),
)
except ImportError as e:
raise ImportError(
"Google Vertex dependencies are not installed. Please install them to use Vertex embedding."
) from e
@staticmethod
def _configure_google(config, model_name):

View File

@@ -0,0 +1,40 @@
from typing import Any, List, Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
import requests
from chromadb import Documents, Embeddings
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
class FixedGoogleVertexEmbeddingFunction(GoogleVertexEmbeddingFunction):
"""
A wrapper around ChromaDB's GoogleVertexEmbeddingFunction that fixes the URL typo
where 'publishers/goole' is incorrectly used instead of 'publishers/google'.
Issue reference: https://github.com/crewaiinc/crewai/issues/2690
"""
def __init__(self,
model_name: str = "textembedding-gecko",
api_key: Optional[str] = None,
**kwargs: Any):
api_key_str = "" if api_key is None else api_key
super().__init__(model_name=model_name, api_key=api_key_str, **kwargs)
self._original_post = requests.post
requests.post = self._patched_post
def __del__(self):
if hasattr(self, '_original_post'):
requests.post = self._original_post
def _patched_post(self, url, *args, **kwargs):
if 'publishers/goole' in url:
url = url.replace('publishers/goole', 'publishers/google')
return self._original_post(url, *args, **kwargs)
def __call__(self, input: Documents) -> Embeddings:
return super().__call__(input)

View File

@@ -0,0 +1,37 @@
from unittest.mock import MagicMock, patch
import pytest
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction
class TestEmbeddingConfigurator:
@pytest.fixture
def embedding_configurator(self):
return EmbeddingConfigurator()
def test_configure_vertexai(self, embedding_configurator):
with patch('crewai.utilities.embedding_functions.FixedGoogleVertexEmbeddingFunction') as mock_class:
mock_instance = MagicMock()
mock_class.return_value = mock_instance
config = {
"provider": "vertexai",
"config": {
"api_key": "test-key",
"model": "test-model",
"project_id": "test-project",
"region": "test-region"
}
}
result = embedding_configurator.configure_embedder(config)
mock_class.assert_called_once_with(
model_name="test-model",
api_key="test-key",
project_id="test-project",
region="test-region"
)
assert result == mock_instance

View File

@@ -0,0 +1,57 @@
from unittest.mock import MagicMock, patch
import pytest
import requests
from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction
class TestFixedGoogleVertexEmbeddingFunction:
@pytest.fixture
def embedding_function(self):
with patch('requests.post') as mock_post:
mock_response = MagicMock()
mock_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
mock_post.return_value = mock_response
function = FixedGoogleVertexEmbeddingFunction(
model_name="test-model",
api_key="test-key"
)
yield function, mock_post
if hasattr(function, '_original_post'):
requests.post = function._original_post
def test_url_correction(self, embedding_function):
function, mock_post = embedding_function
typo_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/goole/models/test-model:predict"
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/test-model:predict"
with patch.object(function, '_original_post') as mock_original_post:
mock_response = MagicMock()
mock_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
mock_original_post.return_value = mock_response
response = function._patched_post(typo_url, json={})
mock_original_post.assert_called_once()
call_args = mock_original_post.call_args
assert call_args[0][0] == expected_url
def test_embedding_call(self, embedding_function):
function, mock_post = embedding_function
mock_response = MagicMock()
mock_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
mock_post.return_value = mock_response
embeddings = function(["test text"])
mock_post.assert_called_once()
assert isinstance(embeddings, list)
assert len(embeddings) > 0