mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Compare commits
3 Commits
devin/1745 ... devin/1745
| Author | SHA1 | Date |
|---|---|---|
| | 3572ecf1c7 | |
| | 4cf90dbcb7 | |
| | 7a7736cfc6 | |
@@ -62,3 +62,5 @@ def reset_memories_command(
    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)
        if "No crew found" in str(e):
            click.echo("This error might occur when running the command in a non-CrewAI project directory.", err=True)
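For context, this hunk appends a directory hint whenever the caught error message mentions a missing crew. A minimal self-contained sketch of that pattern (hypothetical command body with a simulated failure, not the real CLI):

```python
import click


@click.command()
def reset_memories():
    try:
        # Simulated failure standing in for the real reset logic.
        raise RuntimeError("No crew found in this directory")
    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)
        if "No crew found" in str(e):
            click.echo(
                "This error might occur when running the command "
                "in a non-CrewAI project directory.",
                err=True,
            )


if __name__ == "__main__":
    reset_memories()
```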
@@ -82,11 +82,16 @@ def _get_project_attribute(
        with open(pyproject_path, "r") as f:
            pyproject_content = parse_toml(f.read())

        dependencies = (
            _get_nested_value(pyproject_content, ["project", "dependencies"]) or []
        )
        if not any(True for dep in dependencies if "crewai" in dep):
            raise Exception("crewai is not in the dependencies.")
        import inspect
        calling_frame = inspect.currentframe()
        if calling_frame and calling_frame.f_back and calling_frame.f_back.f_code:
            calling_function = calling_frame.f_back.f_code.co_name
            if calling_function != "reset_memories":
                dependencies = (
                    _get_nested_value(pyproject_content, ["project", "dependencies"]) or []
                )
                if not any(True for dep in dependencies if "crewai" in dep):
                    raise Exception("crewai is not in the dependencies.")

        attribute = _get_nested_value(pyproject_content, keys)
    except FileNotFoundError:
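This hunk moves the crewai dependency check behind a caller inspection so that `reset_memories` bypasses it. A standalone sketch of that guard, with hypothetical helper names, assuming only that `inspect.currentframe().f_back` exposes the caller's code object:

```python
import inspect


def check_dependencies(dependencies):
    # Stand-in for the pyproject dependency validation in the diff.
    if not any("crewai" in dep for dep in dependencies):
        raise Exception("crewai is not in the dependencies.")


def get_project_attribute(dependencies):
    # Walk one frame up the stack to learn which function called us.
    frame = inspect.currentframe()
    if frame and frame.f_back and frame.f_back.f_code:
        caller = frame.f_back.f_code.co_name
        # Skip validation only for the reset_memories entry point.
        if caller != "reset_memories":
            check_dependencies(dependencies)


def reset_memories():
    get_project_attribute([])  # guard matches: check is skipped


reset_memories()  # succeeds despite the empty dependency list

try:
    get_project_attribute([])  # caller is "<module>": check runs
except Exception as e:
    print(e)  # -> crewai is not in the dependencies.
```

A guard like this is sensitive to renames of the calling function, since it matches on the literal name.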
@@ -37,7 +37,6 @@ with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    import litellm
    from litellm import Choices
    from litellm.exceptions import ContextWindowExceededError
    from litellm.litellm_core_utils.get_supported_openai_params import (
        get_supported_openai_params,
    )
@@ -598,11 +597,6 @@ class LLM(BaseLLM):
            self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
            return full_response

        except ContextWindowExceededError as e:
            # Catch context window errors from litellm and convert them to our own exception type.
            # This exception is handled by CrewAgentExecutor._invoke_loop() which can then
            # decide whether to summarize the content or abort based on the respect_context_window flag.
            raise LLMContextLengthExceededException(str(e))
        except Exception as e:
            logging.error(f"Error in streaming response: {str(e)}")
            if full_response.strip():
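Both this hunk and the next touch the same pattern: catching litellm's ContextWindowExceededError at the call site and re-raising it as crewai's own LLMContextLengthExceededException. A self-contained sketch of that conversion, using stand-in classes rather than the real litellm and crewai types:

```python
class ContextWindowExceededError(Exception):
    """Stand-in for litellm.exceptions.ContextWindowExceededError."""


class LLMContextLengthExceededException(Exception):
    """Stand-in for crewai's context-length exception."""


def completion(**params):
    # Stand-in for litellm.completion; always overflows for the demo.
    raise ContextWindowExceededError(
        "This model's maximum context length is 8192 tokens."
    )


def call_llm(**params):
    try:
        return completion(**params)
    except ContextWindowExceededError as e:
        # Convert the provider error into the project's own type so one
        # upstream handler can decide to summarize or abort.
        raise LLMContextLengthExceededException(str(e))


try:
    call_llm(model="gpt-4")
except LLMContextLengthExceededException as e:
    print(f"caught: {e}")
```

The diff removes these call-site conversions; the outer handler shown in the `@@ -885` hunk below appears to take over that responsibility by inspecting the error text instead.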
@@ -717,16 +711,7 @@ class LLM(BaseLLM):
            str: The response text
        """
        # --- 1) Make the completion call
        try:
            # Attempt to make the completion call, but catch context window errors
            # and convert them to our own exception type for consistent handling
            # across the codebase. This allows CrewAgentExecutor to handle context
            # length issues appropriately.
            response = litellm.completion(**params)
        except ContextWindowExceededError as e:
            # Convert litellm's context window error to our own exception type
            # for consistent handling in the rest of the codebase
            raise LLMContextLengthExceededException(str(e))
        response = litellm.completion(**params)

        # --- 2) Extract response message and content
        response_message = cast(Choices, cast(ModelResponse, response).choices)[
@@ -885,17 +870,15 @@ class LLM(BaseLLM):
                params, callbacks, available_functions
            )

        except LLMContextLengthExceededException:
            # Re-raise LLMContextLengthExceededException as it should be handled
            # by the CrewAgentExecutor._invoke_loop method, which can then decide
            # whether to summarize the content or abort based on the respect_context_window flag
            raise
        except Exception as e:
            crewai_event_bus.emit(
                self,
                event=LLMCallFailedEvent(error=str(e)),
            )
            logging.error(f"LiteLLM call failed: {str(e)}")
            if not LLMContextLengthExceededException(
                str(e)
            )._is_context_limit_error(str(e)):
                logging.error(f"LiteLLM call failed: {str(e)}")
            raise

    def _handle_emit_call_events(self, response: Any, call_type: LLMCallType):
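The replacement branch here constructs an LLMContextLengthExceededException only to call its `_is_context_limit_error` helper on the error text. A minimal sketch of that shape; the substring list below is an assumption for illustration, not the real crewai implementation:

```python
import logging


class LLMContextLengthExceededException(Exception):
    # Assumed phrases; the real class in crewai may match differently.
    CONTEXT_LIMIT_ERRORS = [
        "maximum context length",
        "context window",
        "context length exceeded",
    ]

    def _is_context_limit_error(self, error_message: str) -> bool:
        # Case-insensitive substring match against known provider wordings.
        lowered = error_message.lower()
        return any(phrase in lowered for phrase in self.CONTEXT_LIMIT_ERRORS)


def handle_failure(e: Exception) -> None:
    # Mirrors the diff: log only non-context-limit failures, re-raise all.
    if not LLMContextLengthExceededException(str(e))._is_context_limit_error(str(e)):
        logging.error(f"LiteLLM call failed: {str(e)}")
    raise e


try:
    handle_failure(Exception("maximum context length is 8192 tokens"))
except Exception as e:
    print(f"re-raised without logging: {e}")  # matched a context phrase
```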
@@ -104,25 +104,16 @@ class EmbeddingConfigurator:

    @staticmethod
    def _configure_vertexai(config, model_name):
        try:
            from chromadb.utils.embedding_functions.google_embedding_function import (
                GoogleVertexEmbeddingFunction,
            )
            from chromadb.utils.embedding_functions.google_embedding_function import (
                GoogleVertexEmbeddingFunction,
            )

            from crewai.utilities.embedding_functions import (
                FixedGoogleVertexEmbeddingFunction,
            )

            return FixedGoogleVertexEmbeddingFunction(
                model_name=model_name,
                api_key=config.get("api_key"),
                project_id=config.get("project_id"),
                region=config.get("region"),
            )
        except ImportError as e:
            raise ImportError(
                "Google Vertex dependencies are not installed. Please install them to use Vertex embedding."
            ) from e
        return GoogleVertexEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
            project_id=config.get("project_id"),
            region=config.get("region"),
        )

    @staticmethod
    def _configure_google(config, model_name):
@@ -1,40 +0,0 @@
from typing import Any, List, Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

import requests
from chromadb import Documents, Embeddings
from chromadb.utils.embedding_functions.google_embedding_function import (
    GoogleVertexEmbeddingFunction,
)


class FixedGoogleVertexEmbeddingFunction(GoogleVertexEmbeddingFunction):
    """
    A wrapper around ChromaDB's GoogleVertexEmbeddingFunction that fixes the URL typo
    where 'publishers/goole' is incorrectly used instead of 'publishers/google'.

    Issue reference: https://github.com/crewaiinc/crewai/issues/2690
    """

    def __init__(self,
                 model_name: str = "textembedding-gecko",
                 api_key: Optional[str] = None,
                 **kwargs: Any):
        api_key_str = "" if api_key is None else api_key
        super().__init__(model_name=model_name, api_key=api_key_str, **kwargs)

        self._original_post = requests.post
        requests.post = self._patched_post

    def __del__(self):
        if hasattr(self, '_original_post'):
            requests.post = self._original_post

    def _patched_post(self, url, *args, **kwargs):
        if 'publishers/goole' in url:
            url = url.replace('publishers/goole', 'publishers/google')

        return self._original_post(url, *args, **kwargs)

    def __call__(self, input: Documents) -> Embeddings:
        return super().__call__(input)
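The deleted class rebinds the module-global `requests.post` in `__init__` and only restores it in `__del__`, so the patch leaks to every caller in the process for the object's lifetime. As a hedged alternative sketch (not what the codebase does), the same URL correction can be scoped with a context manager so it is always undone on exit:

```python
from contextlib import contextmanager

import requests


@contextmanager
def fixed_vertex_urls():
    """Temporarily rewrite 'publishers/goole' to 'publishers/google'."""
    original_post = requests.post

    def patched_post(url, *args, **kwargs):
        # Correct the typo before the request leaves the process.
        return original_post(
            url.replace("publishers/goole", "publishers/google"),
            *args,
            **kwargs,
        )

    requests.post = patched_post
    try:
        yield
    finally:
        # Restore the original binding even if the body raises.
        requests.post = original_post
```

Any `requests.post` issued inside a `with fixed_vertex_urls():` block sees the corrected URL, and the global is restored deterministically rather than at garbage-collection time.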
@@ -132,6 +132,15 @@ def test_reset_knowledge(mock_get_crew, runner):
    assert result.output == "Knowledge has been reset.\n"


@mock.patch("crewai.cli.reset_memories_command.get_crew")
def test_reset_knowledge_with_kn_flag(mock_get_crew, runner):
    mock_crew = mock.Mock()
    mock_get_crew.return_value = mock_crew
    result = runner.invoke(reset_memories, ["-kn"])
    mock_crew.reset_memories.assert_called_once_with(command_type="knowledge")
    assert result.output == "Knowledge has been reset.\n"


def test_reset_no_memory_flags(runner):
    result = runner.invoke(
        reset_memories,
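The new test relies on a `runner` fixture that the hunk does not show; a typical definition (an assumption, but the `runner.invoke(...)` call shape matches click's test API) would be:

```python
import pytest
from click.testing import CliRunner


@pytest.fixture
def runner():
    # click's test runner invokes commands in-process and captures output.
    return CliRunner()
```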
@@ -373,45 +373,6 @@ def get_weather_tool_schema():
        },
    }

def test_context_window_exceeded_error_handling():
    """Test that litellm.ContextWindowExceededError is converted to LLMContextLengthExceededException."""
    from litellm.exceptions import ContextWindowExceededError

    from crewai.utilities.exceptions.context_window_exceeding_exception import (
        LLMContextLengthExceededException,
    )

    llm = LLM(model="gpt-4")

    # Test non-streaming response
    with patch("litellm.completion") as mock_completion:
        mock_completion.side_effect = ContextWindowExceededError(
            "This model's maximum context length is 8192 tokens. However, your messages resulted in 10000 tokens.",
            model="gpt-4",
            llm_provider="openai"
        )

        with pytest.raises(LLMContextLengthExceededException) as excinfo:
            llm.call("This is a test message")

        assert "context length exceeded" in str(excinfo.value).lower()
        assert "8192 tokens" in str(excinfo.value)

    # Test streaming response
    llm = LLM(model="gpt-4", stream=True)
    with patch("litellm.completion") as mock_completion:
        mock_completion.side_effect = ContextWindowExceededError(
            "This model's maximum context length is 8192 tokens. However, your messages resulted in 10000 tokens.",
            model="gpt-4",
            llm_provider="openai"
        )

        with pytest.raises(LLMContextLengthExceededException) as excinfo:
            llm.call("This is a test message")

        assert "context length exceeded" in str(excinfo.value).lower()
        assert "8192 tokens" in str(excinfo.value)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.fixture
@@ -1,37 +0,0 @@
from unittest.mock import MagicMock, patch

import pytest

from crewai.utilities.embedding_configurator import EmbeddingConfigurator
from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction


class TestEmbeddingConfigurator:
    @pytest.fixture
    def embedding_configurator(self):
        return EmbeddingConfigurator()

    def test_configure_vertexai(self, embedding_configurator):
        with patch('crewai.utilities.embedding_functions.FixedGoogleVertexEmbeddingFunction') as mock_class:
            mock_instance = MagicMock()
            mock_class.return_value = mock_instance

            config = {
                "provider": "vertexai",
                "config": {
                    "api_key": "test-key",
                    "model": "test-model",
                    "project_id": "test-project",
                    "region": "test-region"
                }
            }

            result = embedding_configurator.configure_embedder(config)

            mock_class.assert_called_once_with(
                model_name="test-model",
                api_key="test-key",
                project_id="test-project",
                region="test-region"
            )
            assert result == mock_instance
@@ -1,57 +0,0 @@
from unittest.mock import MagicMock, patch

import pytest
import requests

from crewai.utilities.embedding_functions import FixedGoogleVertexEmbeddingFunction


class TestFixedGoogleVertexEmbeddingFunction:
    @pytest.fixture
    def embedding_function(self):
        with patch('requests.post') as mock_post:
            mock_response = MagicMock()
            mock_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
            mock_post.return_value = mock_response

            function = FixedGoogleVertexEmbeddingFunction(
                model_name="test-model",
                api_key="test-key"
            )

            yield function, mock_post

        if hasattr(function, '_original_post'):
            requests.post = function._original_post

    def test_url_correction(self, embedding_function):
        function, mock_post = embedding_function

        typo_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/goole/models/test-model:predict"

        expected_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test-project/locations/us-central1/publishers/google/models/test-model:predict"

        with patch.object(function, '_original_post') as mock_original_post:
            mock_response = MagicMock()
            mock_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
            mock_original_post.return_value = mock_response

            response = function._patched_post(typo_url, json={})

            mock_original_post.assert_called_once()
            call_args = mock_original_post.call_args
            assert call_args[0][0] == expected_url

    def test_embedding_call(self, embedding_function):
        function, mock_post = embedding_function

        mock_response = MagicMock()
        mock_response.json.return_value = {"predictions": [[0.1, 0.2, 0.3]]}
        mock_post.return_value = mock_response

        embeddings = function(["test text"])

        mock_post.assert_called_once()

        assert isinstance(embeddings, list)
        assert len(embeddings) > 0