mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 15:18:29 +00:00
Compare commits
3 Commits
devin/1742
...
devin/1742
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2892cf98f1 | ||
|
|
90f508da12 | ||
|
|
6161e6893e |
@@ -955,20 +955,37 @@ class LLM:
|
||||
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
return self.context_window_size
|
||||
|
||||
def set_callbacks(self, callbacks: List[Any]):
|
||||
def _safe_remove_callback(self, callback_list: List[Any], callback: Any) -> None:
|
||||
"""
|
||||
Safely remove a callback from a list, handling the case where it doesn't exist.
|
||||
|
||||
Args:
|
||||
callback_list: The list of callbacks to remove from
|
||||
callback: The callback to remove
|
||||
"""
|
||||
try:
|
||||
callback_list.remove(callback)
|
||||
except ValueError as e:
|
||||
logging.debug(f"Callback {callback} not found in callback list: {e}")
|
||||
pass
|
||||
|
||||
def set_callbacks(self, callbacks: List[Any]) -> None:
|
||||
"""
|
||||
Attempt to keep a single set of callbacks in litellm by removing old
|
||||
duplicates and adding new ones.
|
||||
|
||||
Args:
|
||||
callbacks: List of callback functions to set
|
||||
"""
|
||||
with suppress_warnings():
|
||||
callback_types = [type(callback) for callback in callbacks]
|
||||
for callback in litellm.success_callback[:]:
|
||||
if type(callback) in callback_types:
|
||||
litellm.success_callback.remove(callback)
|
||||
self._safe_remove_callback(litellm.success_callback, callback)
|
||||
|
||||
for callback in litellm._async_success_callback[:]:
|
||||
if type(callback) in callback_types:
|
||||
litellm._async_success_callback.remove(callback)
|
||||
self._safe_remove_callback(litellm._async_success_callback, callback)
|
||||
|
||||
litellm.callbacks = callbacks
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ class EmbeddingConfigurator:
|
||||
"huggingface": self._configure_huggingface,
|
||||
"watson": self._configure_watson,
|
||||
"custom": self._configure_custom,
|
||||
"openrouter": self._configure_openrouter,
|
||||
}
|
||||
|
||||
def configure_embedder(
|
||||
@@ -211,35 +210,6 @@ class EmbeddingConfigurator:
|
||||
|
||||
return WatsonEmbeddingFunction()
|
||||
|
||||
@staticmethod
|
||||
def _configure_openrouter(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||
"""
|
||||
Configure OpenRouter embedding provider.
|
||||
|
||||
Args:
|
||||
config (Dict[str, Any]): Configuration dictionary containing the API key and optional settings.
|
||||
model_name (str): Name of the embedding model to use.
|
||||
|
||||
Returns:
|
||||
OpenAIEmbeddingFunction: Configured OpenRouter embedding function.
|
||||
|
||||
Raises:
|
||||
ValueError: If the API key is not provided in the config or environment.
|
||||
"""
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
api_key = config.get("api_key") or os.getenv("OPENROUTER_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("OpenRouter API key must be provided either in config or OPENROUTER_API_KEY environment variable")
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=api_key,
|
||||
api_base=config.get("api_base", "https://openrouter.ai/api/v1"),
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_custom(config):
|
||||
custom_embedder = config.get("embedder")
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
from time import sleep
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import litellm
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -443,3 +444,49 @@ def test_tool_execution_error_event():
|
||||
assert event.tool_args == {"param": "test"}
|
||||
assert event.tool_class == failing_tool
|
||||
assert "Tool execution failed!" in event.error
|
||||
|
||||
def test_set_callbacks_with_nonexistent_callback():
|
||||
"""Test that set_callbacks handles the case where a callback doesn't exist in the list."""
|
||||
# Create a mock callback
|
||||
class MockCallback:
|
||||
def __init__(self):
|
||||
self.called = False
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.called = True
|
||||
|
||||
# Create a test callback
|
||||
test_callback = MockCallback()
|
||||
|
||||
# Make sure the callback lists are empty
|
||||
original_success_callbacks = litellm.success_callback.copy()
|
||||
original_async_callbacks = litellm._async_success_callback.copy()
|
||||
|
||||
try:
|
||||
# Clear the callback lists to ensure clean state
|
||||
litellm.success_callback.clear()
|
||||
litellm._async_success_callback.clear()
|
||||
|
||||
# Create an LLM instance
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
|
||||
# Call set_callbacks with our test callback - this should work without error
|
||||
llm.set_callbacks([test_callback])
|
||||
|
||||
# Now call set_callbacks again - this should also work without error
|
||||
# even though the callback is already in the list
|
||||
llm.set_callbacks([test_callback])
|
||||
|
||||
# Now remove the callback and try to remove it again - this should not raise an error
|
||||
litellm.success_callback.clear()
|
||||
litellm._async_success_callback.clear()
|
||||
|
||||
# This would previously fail with "list.remove(x): x not in list"
|
||||
llm.set_callbacks([test_callback])
|
||||
|
||||
assert True # If we get here, no exception was raised
|
||||
|
||||
finally:
|
||||
# Restore the original callbacks
|
||||
litellm.success_callback = original_success_callbacks
|
||||
litellm._async_success_callback = original_async_callbacks
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||
|
||||
|
||||
def test_openrouter_embedder_configuration():
|
||||
# Setup
|
||||
configurator = EmbeddingConfigurator()
|
||||
mock_openai_embedding = MagicMock()
|
||||
|
||||
with patch(
|
||||
"chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction",
|
||||
return_value=mock_openai_embedding,
|
||||
) as mock_embedder:
|
||||
# Test with provided config
|
||||
embedder_config = {
|
||||
"provider": "openrouter",
|
||||
"config": {
|
||||
"api_key": "test-key",
|
||||
"model": "test-model",
|
||||
},
|
||||
}
|
||||
|
||||
# Execute
|
||||
result = configurator.configure_embedder(embedder_config)
|
||||
|
||||
# Verify
|
||||
assert result == mock_openai_embedding
|
||||
mock_embedder.assert_called_once_with(
|
||||
api_key="test-key",
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
model_name="test-model",
|
||||
)
|
||||
|
||||
|
||||
def test_openrouter_embedder_configuration_with_env_var():
|
||||
# Setup
|
||||
configurator = EmbeddingConfigurator()
|
||||
mock_openai_embedding = MagicMock()
|
||||
|
||||
# Test with API key from environment variable
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "env-key"}), \
|
||||
patch(
|
||||
"chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction",
|
||||
return_value=mock_openai_embedding,
|
||||
) as mock_embedder:
|
||||
# Config without API key
|
||||
embedder_config = {
|
||||
"provider": "openrouter",
|
||||
"config": {
|
||||
"model": "test-model",
|
||||
},
|
||||
}
|
||||
|
||||
# Execute
|
||||
result = configurator.configure_embedder(embedder_config)
|
||||
|
||||
# Verify
|
||||
assert result == mock_openai_embedding
|
||||
mock_embedder.assert_called_once_with(
|
||||
api_key="env-key",
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
model_name="test-model",
|
||||
)
|
||||
|
||||
|
||||
def test_openrouter_embedder_configuration_missing_api_key():
|
||||
# Setup
|
||||
configurator = EmbeddingConfigurator()
|
||||
|
||||
# Test without API key
|
||||
with patch.dict(os.environ, {}, clear=True), \
|
||||
patch(
|
||||
"chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction",
|
||||
side_effect=Exception("Should not be called"),
|
||||
):
|
||||
# Config without API key
|
||||
embedder_config = {
|
||||
"provider": "openrouter",
|
||||
"config": {
|
||||
"model": "test-model",
|
||||
},
|
||||
}
|
||||
|
||||
# Verify error is raised
|
||||
with pytest.raises(ValueError, match="OpenRouter API key must be provided"):
|
||||
configurator.configure_embedder(embedder_config)
|
||||
Reference in New Issue
Block a user