mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-03 21:28:29 +00:00
Compare commits
4 Commits
devin/1764
...
devin/1745
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5623e2c851 | ||
|
|
50059c7120 | ||
|
|
335f1dfdf8 | ||
|
|
1f2def2cbe |
@@ -104,16 +104,25 @@ class EmbeddingConfigurator:
|
||||
|
||||
@staticmethod
|
||||
def _configure_vertexai(config, model_name):
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
try:
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
|
||||
return GoogleVertexEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
project_id=config.get("project_id"),
|
||||
region=config.get("region"),
|
||||
)
|
||||
from crewai.utilities.embedding_functions import (
|
||||
FixedGoogleVertexEmbeddingFunction,
|
||||
)
|
||||
|
||||
return FixedGoogleVertexEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
project_id=config.get("project_id"),
|
||||
region=config.get("region"),
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Google Vertex dependencies are not installed. Please install them to use Vertex embedding."
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def _configure_google(config, model_name):
|
||||
|
||||
40
src/crewai/utilities/embedding_functions.py
Normal file
40
src/crewai/utilities/embedding_functions.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Any, List, Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
import requests
|
||||
from chromadb import Documents, Embeddings
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class FixedGoogleVertexEmbeddingFunction(GoogleVertexEmbeddingFunction):
|
||||
"""
|
||||
A wrapper around ChromaDB's GoogleVertexEmbeddingFunction that fixes the URL typo
|
||||
where 'publishers/goole' is incorrectly used instead of 'publishers/google'.
|
||||
|
||||
Issue reference: https://github.com/crewaiinc/crewai/issues/2690
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_name: str = "textembedding-gecko",
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs: Any):
|
||||
api_key_str = "" if api_key is None else api_key
|
||||
super().__init__(model_name=model_name, api_key=api_key_str, **kwargs)
|
||||
|
||||
self._original_post = requests.post
|
||||
requests.post = self._patched_post
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, '_original_post'):
|
||||
requests.post = self._original_post
|
||||
|
||||
def _patched_post(self, url, *args, **kwargs):
|
||||
if 'publishers/goole' in url:
|
||||
url = url.replace('publishers/goole', 'publishers/google')
|
||||
|
||||
return self._original_post(url, *args, **kwargs)
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
return super().__call__(input)
|
||||
37
tests/utilities/test_embedding_configurator.py
Normal file
37
tests/utilities/test_embedding_configurator.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||
from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction
|
||||
|
||||
|
||||
class TestEmbeddingConfigurator:
|
||||
@pytest.fixture
|
||||
def embedding_configurator(self):
|
||||
return EmbeddingConfigurator()
|
||||
|
||||
def test_configure_vertexai(self, embedding_configurator):
|
||||
with patch('crewai.utilities.embedding_functions.FixedGoogleVertexEmbeddingFunction') as mock_class:
|
||||
mock_instance = MagicMock()
|
||||
mock_class.return_value = mock_instance
|
||||
|
||||
config = {
|
||||
"provider": "vertexai",
|
||||
"config": {
|
||||
"api_key": "test-key",
|
||||
"model": "test-model",
|
||||
"project_id": "test-project",
|
||||
"region": "test-region"
|
||||
}
|
||||
}
|
||||
|
||||
result = embedding_configurator.configure_embedder(config)
|
||||
|
||||
mock_class.assert_called_once_with(
|
||||
model_name="test-model",
|
||||
api_key="test-key",
|
||||
project_id="test-project",
|
||||
region="test-region"
|
||||
)
|
||||
assert result == mock_instance
|
||||
57
tests/utilities/test_embedding_functions.py
Normal file
57
tests/utilities/test_embedding_functions.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction
|
||||
|
||||
|
||||
class TestFixedGoogleVertexEmbeddingFunction:
|
||||
@pytest.fixture
|
||||
def embedding_function(self):
|
||||
with patch('requests.post') as mock_post:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
function = FixedGoogleVertexEmbeddingFunction(
|
||||
model_name="test-model",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
yield function, mock_post
|
||||
|
||||
if hasattr(function, '_original_post'):
|
||||
requests.post = function._original_post
|
||||
|
||||
def test_url_correction(self, embedding_function):
|
||||
function, mock_post = embedding_function
|
||||
|
||||
typo_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/goole/models/test-model:predict"
|
||||
|
||||
expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/test-model:predict"
|
||||
|
||||
with patch.object(function, '_original_post') as mock_original_post:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
|
||||
mock_original_post.return_value = mock_response
|
||||
|
||||
response = function._patched_post(typo_url, json={})
|
||||
|
||||
mock_original_post.assert_called_once()
|
||||
call_args = mock_original_post.call_args
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
def test_embedding_call(self, embedding_function):
|
||||
function, mock_post = embedding_function
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
embeddings = function(["test text"])
|
||||
|
||||
mock_post.assert_called_once()
|
||||
|
||||
assert isinstance(embeddings, list)
|
||||
assert len(embeddings) > 0
|
||||
Reference in New Issue
Block a user