Compare commits

..

4 Commits

5 changed files with 13 additions and 129 deletions

View File

@@ -3,7 +3,6 @@ import re
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from litellm import AuthenticationError as LiteLLMAuthenticationError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
@@ -198,19 +197,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return self._invoke_loop(formatted_answer)
except Exception as e:
if isinstance(e, LiteLLMAuthenticationError):
self._logger.log(
level="error",
message="Authentication error with litellm occurred. Please check your API key and configuration.",
color="red",
)
self._logger.log(
level="error",
message=f"Error details: {str(e)}",
color="red",
)
raise e
elif LLMContextLengthExceededException(str(e))._is_context_limit_error(
if LLMContextLengthExceededException(str(e))._is_context_limit_error(
str(e)
):
self._handle_context_length()

View File

@@ -14,13 +14,13 @@ class Knowledge(BaseModel):
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder_config: Optional[Dict[str, Any]] = None
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder_config: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None

View File

@@ -22,7 +22,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
default_factory=list, description="The path to the file"
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
storage: Optional[KnowledgeStorage] = Field(default=None)
safe_file_paths: List[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
@@ -62,7 +62,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def _save_documents(self):
"""Save the documents to the storage."""
self.storage.save(self.chunks)
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
def convert_to_path(self, path: Union[Path, str]) -> Path:
"""Convert a path to a Path object."""

View File

@@ -16,7 +16,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
storage: Optional[KnowledgeStorage] = Field(default=None)
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: Optional[str] = Field(default=None)
@@ -46,4 +46,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
"""
self.storage.save(self.chunks)
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")

View File

@@ -1308,115 +1308,6 @@ def test_llm_call_with_error():
llm.call(messages)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_litellm_auth_error_handling():
"""Test that LiteLLM authentication errors are handled correctly and not retried."""
from litellm import AuthenticationError as LiteLLMAuthenticationError
# Create an agent with a mocked LLM and max_retry_limit=0
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
llm=LLM(model="gpt-4"),
max_retry_limit=0, # Disable retries for authentication errors
max_iter=1, # Limit to one iteration to prevent multiple calls
)
# Create a task
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
# Mock the LLM call to raise LiteLLMAuthenticationError
with (
patch.object(LLM, "call") as mock_llm_call,
pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"),
):
mock_llm_call.side_effect = LiteLLMAuthenticationError(
message="Invalid API key",
llm_provider="openai",
model="gpt-4"
)
agent.execute_task(task)
# Verify the call was only made once (no retries)
mock_llm_call.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_agent_executor_litellm_auth_error():
"""Test that CrewAgentExecutor properly identifies and handles LiteLLM authentication errors."""
from litellm import AuthenticationError as LiteLLMAuthenticationError
from crewai.utilities import Logger
from crewai.agents.tools_handler import ToolsHandler
# Create an agent and executor with max_retry_limit=0
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
llm=LLM(model="gpt-4"),
max_retry_limit=0, # Disable retries for authentication errors
)
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
# Create executor with all required parameters
executor = CrewAgentExecutor(
agent=agent,
task=task,
llm=agent.llm,
crew=None,
prompt={
"system": "You are a test agent",
"user": "Execute the task: {input}"
},
max_iter=5,
tools=[],
tools_names="",
stop_words=[],
tools_description="",
tools_handler=ToolsHandler(),
)
# Mock the LLM call to raise LiteLLMAuthenticationError
with (
patch.object(LLM, "call") as mock_llm_call,
patch.object(Logger, "log") as mock_logger,
pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"),
):
mock_llm_call.side_effect = LiteLLMAuthenticationError(
message="Invalid API key",
llm_provider="openai",
model="gpt-4"
)
executor.invoke({
"input": "test input",
"tool_names": "", # Required template variable
"tools": "", # Required template variable
})
# Verify error handling
mock_logger.assert_any_call(
level="error",
message="Authentication error with litellm occurred. Please check your API key and configuration.",
color="red",
)
mock_logger.assert_any_call(
level="error",
message="Error details: litellm.AuthenticationError: Invalid API key",
color="red",
)
# Verify the call was only made once (no retries)
mock_llm_call.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_context_length_exceeds_limit():
agent = Agent(