Compare commits

..

3 Commits

Author SHA1 Message Date
Devin AI
b15289b5ae Improve code based on PR feedback: add type hints, error handling, docs, and tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-24 15:00:24 +00:00
Devin AI
c340400582 Fix import ordering in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-24 14:52:57 +00:00
Devin AI
49b19a3b2a Fix AttributeError when agent_info is a list in YAML configuration
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-24 14:51:06 +00:00
5 changed files with 137 additions and 270 deletions

View File

@@ -1,11 +0,0 @@
"""
Patches module for CrewAI.
This module contains patches for external dependencies to fix known issues.
Version: 1.0.0
"""
from crewai.patches.litellm_patch import apply_patches, patch_litellm_ollama_pt
__all__ = ["apply_patches", "patch_litellm_ollama_pt"]

View File

@@ -1,186 +0,0 @@
"""
Patch for litellm to fix IndexError in ollama_pt function.
This patch addresses issue #2744 in the crewAI repository, where an IndexError occurs
in litellm's Ollama prompt template function when CrewAI Agent with Tools uses Ollama/Qwen models.
Version: 1.0.0
"""
import json
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
# Set up logging
logger = logging.getLogger(__name__)
# Patch version
PATCH_VERSION = "1.0.0"
class PatchApplicationError(Exception):
    """Raised when a dependency patch cannot be applied."""
def apply_patches() -> bool:
    """
    Apply all patches to fix known issues with dependencies.

    Returns:
        bool: True if all patches were applied successfully, False otherwise.
    """
    applied = patch_litellm_ollama_pt()
    logger.info(f"LiteLLM ollama_pt patch applied: {applied}")
    return applied
def patch_litellm_ollama_pt() -> bool:
    """
    Patch the ollama_pt function in litellm to fix IndexError.

    The issue occurs when accessing messages[msg_i].get("tool_calls") without checking
    if msg_i is within bounds of the messages list. This happens after tool execution
    during the next LLM call.

    Returns:
        bool: True if the patch was applied successfully, False otherwise.

    Raises:
        PatchApplicationError: If there's an error during patch application.
    """
    try:
        # Import the module containing the function to patch.
        # Done lazily so importing this module does not require litellm.
        import litellm.litellm_core_utils.prompt_templates.factory as factory

        # Define a patched version of the function
        def patched_ollama_pt(model: str, messages: List[Dict]) -> Dict[str, Any]:
            """
            Patched version of ollama_pt that adds bounds checking.

            This fixes the IndexError that occurs when the assistant message is the last
            message in the list and msg_i goes out of bounds.

            Args:
                model: The model name.
                messages: The list of messages to process.

            Returns:
                Dict containing the prompt and images.
            """
            # Roles whose content is merged into a single "### User:" section.
            user_message_types = {"user", "tool", "function"}
            msg_i = 0
            images: List[str] = []
            prompt = ""
            # Handle empty messages list
            if not messages:
                return {"prompt": prompt, "images": images}
            while msg_i < len(messages):
                # Remember where this pass started so we can detect no progress.
                init_msg_i = msg_i
                user_content_str = ""
                ## MERGE CONSECUTIVE USER CONTENT ##
                while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
                    msg_content = messages[msg_i].get("content")
                    if msg_content:
                        if isinstance(msg_content, list):
                            # Multi-part content: collect image URLs and concatenate text parts.
                            for m in msg_content:
                                if m.get("type", "") == "image_url":
                                    if isinstance(m["image_url"], str):
                                        images.append(m["image_url"])
                                    elif isinstance(m["image_url"], dict):
                                        images.append(m["image_url"]["url"])
                                elif m.get("type", "") == "text":
                                    user_content_str += m["text"]
                        else:
                            # Tool message content will always be a string
                            user_content_str += msg_content
                    msg_i += 1
                if user_content_str:
                    prompt += f"### User:\n{user_content_str}\n\n"
                # Delegate system-message handling to litellm's own helper;
                # it returns the system text (if any) and the advanced index.
                system_content_str, msg_i = factory._handle_ollama_system_message(
                    messages, prompt, msg_i
                )
                if system_content_str:
                    prompt += f"### System:\n{system_content_str}\n\n"
                assistant_content_str = ""
                ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
                while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
                    assistant_content_str += factory.convert_content_list_to_str(messages[msg_i])
                    msg_i += 1
                # Add bounds check before accessing messages[msg_i]
                # This is the key fix for the IndexError
                if msg_i < len(messages):
                    tool_calls = messages[msg_i].get("tool_calls")
                    ollama_tool_calls = []
                    if tool_calls:
                        # Convert OpenAI-style tool calls into the shape serialized below.
                        for call in tool_calls:
                            call_id = call["id"]
                            function_name = call["function"]["name"]
                            # NOTE(review): assumes call["function"]["arguments"] is a JSON
                            # string — json.loads will raise otherwise; confirm upstream.
                            arguments = json.loads(call["function"]["arguments"])
                            ollama_tool_calls.append(
                                {
                                    "id": call_id,
                                    "type": "function",
                                    "function": {
                                        "name": function_name,
                                        "arguments": arguments,
                                    },
                                }
                            )
                    if ollama_tool_calls:
                        assistant_content_str += (
                            f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
                        )
                        # NOTE(review): indentation was mangled in the reviewed diff;
                        # reconstructed so msg_i only advances when a tool-call message
                        # was actually consumed — confirm against the upstream litellm
                        # ollama_pt implementation.
                        msg_i += 1
                if assistant_content_str:
                    prompt += f"### Assistant:\n{assistant_content_str}\n\n"
                if msg_i == init_msg_i:  # prevent infinite loops
                    raise factory.litellm.BadRequestError(
                        message=factory.BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
                        model=model,
                        llm_provider="ollama",
                    )
            response_dict = {
                "prompt": prompt,
                "images": images,
            }
            return response_dict

        # Replace the original function with our patched version
        factory.ollama_pt = patched_ollama_pt
        logger.info(f"Successfully applied litellm ollama_pt patch version {PATCH_VERSION}")
        return True
    except Exception as e:
        # Best-effort: report failure via the return value rather than raising,
        # so callers (apply_patches) can decide how to react.
        error_msg = f"Failed to apply litellm ollama_pt patch: {e}"
        logger.error(error_msg)
        return False
# For backwards compatibility
def patch_litellm() -> bool:
"""
Legacy function for backwards compatibility.
Returns:
bool: True if the patch was applied successfully, False otherwise.
"""
try:
return patch_litellm_ollama_pt()
except Exception as e:
logger.error(f"Failed to apply legacy litellm patch: {e}")
return False

View File

@@ -1,6 +1,6 @@
import inspect
from pathlib import Path
from typing import Any, Callable, Dict, TypeVar, cast
from typing import Any, Callable, Dict, List, TypeVar, Union, cast
import yaml
from dotenv import load_dotenv
@@ -116,13 +116,33 @@ def CrewBase(cls: T) -> T:
def _map_agent_variables(
self,
agent_name: str,
agent_info: Dict[str, Any],
agent_info: Union[Dict[str, Any], List[Dict[str, Any]]],
agents: Dict[str, Callable],
llms: Dict[str, Callable],
tool_functions: Dict[str, Callable],
cache_handler_functions: Dict[str, Callable],
callbacks: Dict[str, Callable],
) -> None:
"""Maps agent variables from configuration to internal state.
Args:
agent_name: Name of the agent.
agent_info: Configuration as a dictionary or list of configurations.
agents: Dictionary of agent functions.
llms: Dictionary of LLM functions.
tool_functions: Dictionary of tool functions.
cache_handler_functions: Dictionary of cache handler functions.
callbacks: Dictionary of callback functions.
Raises:
ValueError: When an empty list is provided as agent_info.
"""
# If agent_info is a list, use the first item as the configuration
if isinstance(agent_info, list):
if not agent_info:
raise ValueError(f"Empty agent configuration list for agent {agent_name}")
agent_info = agent_info[0]
if llm := agent_info.get("llm"):
try:
self.agents_config[agent_name]["llm"] = llms[llm]()

View File

@@ -1,71 +0,0 @@
"""
Test for the litellm patch that fixes the IndexError in ollama_pt function.
"""
import sys
import unittest
from unittest.mock import MagicMock, patch
import litellm
import pytest
from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt
from crewai.patches.litellm_patch import patch_litellm_ollama_pt
class TestLitellmPatch(unittest.TestCase):
    """Tests for the litellm ollama_pt patch."""

    def _call_patched(self, messages):
        """Apply the patch, run the patched ollama_pt, then restore the original."""
        factory = litellm.litellm_core_utils.prompt_templates.factory
        saved_ollama_pt = factory.ollama_pt
        try:
            self.assertTrue(patch_litellm_ollama_pt(), "Patch application failed")
            # Look the function up on the module so the patched version is used.
            return factory.ollama_pt("qwen3:4b", messages)
        finally:
            # Always restore the original function to avoid affecting other tests.
            litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = saved_ollama_pt

    def test_ollama_pt_patch_fixes_index_error(self):
        """Test that the patch fixes the IndexError in ollama_pt."""
        # A message list where the assistant message is the last one —
        # exactly the shape that used to trigger the IndexError.
        result = self._call_patched(
            [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there"},
            ]
        )
        self.assertIn("prompt", result)
        self.assertIn("images", result)
        self.assertIn("### User:\nHello", result["prompt"])
        self.assertIn("### Assistant:\nHi there", result["prompt"])

    def test_ollama_pt_patch_with_empty_messages(self):
        """Test that the patch handles empty message lists."""
        result = self._call_patched([])
        self.assertIn("prompt", result)
        self.assertIn("images", result)
        self.assertEqual("", result["prompt"])
        self.assertEqual([], result["images"])
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,115 @@
import os
import sys
import tempfile
from pathlib import Path
import pytest
import yaml
class TestYamlConfig:
    """Tests for YAML agent-configuration handling (list vs. dict formats)."""

    @staticmethod
    def _map_agent_variables(agent_name, agent_info):
        """Simulate the list-handling fix applied in CrewBase._map_agent_variables.

        Args:
            agent_name: Name of the agent (used in the error message).
            agent_info: A configuration dict, or a list of configuration dicts.

        Returns:
            The configuration dict (the first item when a list is given).

        Raises:
            ValueError: If agent_info is an empty list.
        """
        if isinstance(agent_info, list):
            if not agent_info:
                raise ValueError(f"Empty agent configuration list for agent {agent_name}")
            agent_info = agent_info[0]
        return agent_info

    def test_list_format_in_yaml(self):
        """A list-valued agent entry is reduced to its first configuration dict."""
        yaml_content = """
test_agent:
  - name: test_agent
    role: Test Agent
    goal: Test Goal
"""
        data = yaml.safe_load(yaml_content)
        agent_info = data["test_agent"]
        assert isinstance(agent_info, list)
        config = self._map_agent_variables("test_agent", agent_info)
        # Before the fix, .get() on the list raised AttributeError.
        # Fix over the original test: actually assert the resolved value.
        assert config.get("name") == "test_agent"

    def test_empty_list_in_yaml(self):
        """An empty list raises a ValueError naming the agent."""
        data = yaml.safe_load("test_agent: []")
        agent_info = data["test_agent"]
        assert isinstance(agent_info, list)
        assert len(agent_info) == 0
        with pytest.raises(
            ValueError, match="Empty agent configuration list for agent test_agent"
        ):
            self._map_agent_variables("test_agent", agent_info)

    def test_multiple_items_in_list(self):
        """When multiple items are in the list, only the first one is used."""
        yaml_content = """
test_agent:
  - name: first_agent
    role: First Agent
    goal: First Goal
  - name: second_agent
    role: Second Agent
    goal: Second Goal
"""
        data = yaml.safe_load(yaml_content)
        agent_info = data["test_agent"]
        assert isinstance(agent_info, list)
        assert len(agent_info) > 1
        config = self._map_agent_variables("test_agent", agent_info)
        assert config.get("name") == "first_agent"