Compare commits

..

4 Commits

Author SHA1 Message Date
Devin AI
0a22cbc349 Enhance set_callbacks with improved type hints, error handling, and expanded test coverage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-03 09:44:06 +00:00
Devin AI
4f5d18a2c9 Fix import sorting with ruff --fix
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-03 09:41:32 +00:00
Devin AI
f6571f114d Fix import sorting in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-03 09:40:38 +00:00
Devin AI
c06eb56cf3 Fix litellm callback removal error (issue #2513)
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-03 09:39:10 +00:00
8 changed files with 140 additions and 390 deletions

View File

@@ -589,23 +589,6 @@ class Crew(BaseModel):
self,
inputs: Optional[Dict[str, Any]] = None,
) -> CrewOutput:
"""
Starts the crew to work on its assigned tasks.
This method initializes all agents, sets up their configurations, and executes
the tasks according to the specified process (sequential or hierarchical).
For each agent, if no function_calling_llm is specified:
- Uses the crew's function_calling_llm if available
- Otherwise uses the agent's own LLM for function calling, enabling
non-OpenAI models to work without requiring OpenAI credentials
Args:
inputs: Optional dictionary of inputs to be used in task execution
Returns:
CrewOutput: The result of the crew's execution
"""
try:
for before_callback in self.before_kickoff_callbacks:
if inputs is None:
@@ -635,10 +618,7 @@ class Crew(BaseModel):
agent.set_knowledge(crew_embedder=self.embedder)
# TODO: Create an AgentFunctionCalling protocol for future refactoring
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
if self.function_calling_llm:
agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
else:
agent.function_calling_llm = agent.llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"

View File

@@ -956,22 +956,42 @@ class LLM(BaseLLM):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
def set_callbacks(self, callbacks: List[Any]) -> None:
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
This method safely updates the litellm callback lists by:
1. Identifying the types of new callbacks
2. Filtering out existing callbacks of the same types
3. Setting the new callbacks
Args:
callbacks: List of callback objects to set in litellm
Returns:
None
Note:
Uses list comprehension to avoid "list.remove(x): x not in list" errors
that can occur with direct removal during iteration.
"""
with suppress_warnings():
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
litellm.callbacks = callbacks
try:
with suppress_warnings():
callback_types = [type(callback) for callback in callbacks]
litellm.success_callback = [
cb for cb in litellm.success_callback if type(cb) not in callback_types
]
litellm._async_success_callback = [
cb for cb in litellm._async_success_callback if type(cb) not in callback_types
]
litellm.callbacks = callbacks
except Exception as e:
logging.error(f"Error setting callbacks: {str(e)}")
raise
def set_env_callbacks(self):
"""

View File

@@ -1819,45 +1819,3 @@ def test_litellm_anthropic_error_handling():
# Verify the LLM call was only made once (no retries)
mock_llm_call.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_uses_own_llm_for_function_calling_when_not_specified():
"""
Test that an agent uses its own LLM for function calling when no function_calling_llm
is specified, ensuring that non-OpenAI models like Gemini can be used without
requiring OpenAI API keys.
This test verifies the fix for issue #2517, where users would get OpenAI authentication
errors even when using non-OpenAI models like Gemini.
"""
@tool
def simple_tool(input_text: str) -> str:
"""A simple tool that returns the input text."""
return f"Tool processed: {input_text}"
agent = Agent(
role="Gemini Agent",
goal="Test Gemini model without OpenAI dependency",
backstory="I am a test agent using Gemini model",
llm="gemini/gemini-1.5-flash", # Using Gemini model
verbose=True
)
with patch.object(LLM, 'supports_function_calling', return_value=True):
with patch('crewai.tools.tool_usage.ToolUsage') as mock_tool_usage:
task = Task(
description="Use the simple tool",
expected_output="Tool result",
agent=agent
)
try:
agent.execute_task(task, tools=[simple_tool])
args, kwargs = mock_tool_usage.call_args
assert kwargs['function_calling_llm'] == agent.llm, "Agent should use its own LLM for function calling"
assert kwargs['function_calling_llm'].model.startswith("gemini"), "Function calling LLM should be Gemini"
except Exception as e:
if "OPENAI_API_KEY" in str(e):
pytest.fail("Test failed with OpenAI API key error despite using Gemini model")

View File

@@ -1,74 +0,0 @@
interactions:
- request:
body: '{"contents": [{"role": "user", "parts": [{"text": "\nCurrent Task: Use
the simple tool\n\nThis is the expected criteria for your final answer: Tool
result\nyou MUST return the actual complete content as the final answer, not
a summary.\n\nBegin! This is VERY important to you, use the tools available
and give your best Final Answer, your job depends on it!\n\nThought:"}]}], "system_instruction":
{"parts": [{"text": "You are Gemini Agent. I am a test agent using Gemini model\nYour
personal goal is: Test Gemini model without OpenAI dependency\nYou ONLY have
access to the following tools, and should NEVER make up tools that are not listed
here:\n\nTool Name: simple_tool\nTool Arguments: {''input_text'': {''description'':
None, ''type'': ''str''}}\nTool Description: A simple tool that returns the
input text.\n\nIMPORTANT: Use the following format in your response:\n\n```\nThought:
you should always think about what to do\nAction: the action to take, only one
name of [simple_tool], just the name, exactly as it''s written.\nAction Input:
the input to the action, just a simple JSON object, enclosed in curly braces,
using \" to wrap keys and values.\nObservation: the result of the action\n```\n\nOnce
all necessary information is gathered, return the following format:\n\n```\nThought:
I now know the final answer\nFinal Answer: the final answer to the original
input question\n```"}]}, "generationConfig": {"stop_sequences": ["\nObservation:"]}}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '1447'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
user-agent:
- litellm/1.60.2
method: POST
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=None
response:
content: "{\n \"error\": {\n \"code\": 400,\n \"message\": \"API key not
valid. Please pass a valid API key.\",\n \"status\": \"INVALID_ARGUMENT\",\n
\ \"details\": [\n {\n \"@type\": \"type.googleapis.com/google.rpc.ErrorInfo\",\n
\ \"reason\": \"API_KEY_INVALID\",\n \"domain\": \"googleapis.com\",\n
\ \"metadata\": {\n \"service\": \"generativelanguage.googleapis.com\"\n
\ }\n },\n {\n \"@type\": \"type.googleapis.com/google.rpc.LocalizedMessage\",\n
\ \"locale\": \"en-US\",\n \"message\": \"API key not valid. Please
pass a valid API key.\"\n }\n ]\n }\n}\n"
headers:
Alt-Svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
Content-Encoding:
- gzip
Content-Type:
- application/json; charset=UTF-8
Date:
- Thu, 03 Apr 2025 11:39:05 GMT
Server:
- scaffolding on HTTPServer2
Server-Timing:
- gfet4t7; dur=47
Transfer-Encoding:
- chunked
Vary:
- Origin
- X-Origin
- Referer
X-Content-Type-Options:
- nosniff
X-Frame-Options:
- SAMEORIGIN
X-XSS-Protection:
- '0'
http_version: HTTP/1.1
status_code: 400
version: 1

File diff suppressed because one or more lines are too long

View File

@@ -24,7 +24,6 @@ from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.tools import tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import Logger
from crewai.utilities.events import (
@@ -4120,52 +4119,3 @@ def test_crew_kickoff_for_each_works_with_manager_agent_copy():
assert crew_copy.manager_agent.backstory == crew.manager_agent.backstory
assert isinstance(crew_copy.manager_agent.agent_executor, CrewAgentExecutor)
assert isinstance(crew_copy.manager_agent.cache_handler, CacheHandler)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_agents_use_own_llm_for_function_calling():
"""
Test that agents in a crew use their own LLM for function calling when no
function_calling_llm is specified for either the agent or the crew.
This test verifies the fix for issue #2517, where users would get OpenAI authentication
errors even when using non-OpenAI models like Gemini. The fix ensures that when no
function_calling_llm is specified, the agent uses its own LLM for function calling.
"""
@tool
def simple_tool(input_text: str) -> str:
"""A simple tool that returns the input text."""
return f"Tool processed: {input_text}"
gemini_agent = Agent(
role="Gemini Agent",
goal="Test Gemini model without OpenAI dependency",
backstory="I am a test agent using Gemini model",
llm="gemini/gemini-1.5-flash", # Using Gemini model
tools=[simple_tool],
verbose=True
)
crew = Crew(
agents=[gemini_agent],
tasks=[
Task(
description="Use the simple tool to process 'test input'",
expected_output="Processed result",
agent=gemini_agent
)
],
verbose=True
)
with patch.object(LLM, 'supports_function_calling', return_value=True):
with patch('crewai.tools.tool_usage.ToolUsage') as mock_tool_usage:
try:
crew.kickoff()
args, kwargs = mock_tool_usage.call_args
assert kwargs['function_calling_llm'] == gemini_agent.llm, "Agent should use its own LLM for function calling"
assert kwargs['function_calling_llm'].model.startswith("gemini"), "Function calling LLM should be Gemini"
except Exception as e:
if "OPENAI_API_KEY" in str(e):
pytest.fail("Test failed with OpenAI API key error despite using Gemini model")

View File

@@ -0,0 +1,105 @@
from typing import Any, List
import litellm
import pytest
from crewai.llm import LLM
class CustomCallback:
"""A simple callback class for testing."""
pass
class DifferentCallback:
"""A different callback class for testing type differentiation."""
pass
@pytest.fixture
def reset_litellm_callbacks():
"""Fixture to reset litellm callbacks after each test."""
original_success_callback = litellm.success_callback
original_async_success_callback = litellm._async_success_callback
yield
litellm.success_callback = original_success_callback
litellm._async_success_callback = original_async_success_callback
def test_set_callbacks_handles_removed_callbacks(reset_litellm_callbacks):
"""Test that set_callbacks handles the case where callbacks are removed during iteration."""
litellm.success_callback = []
litellm._async_success_callback = []
llm = LLM(model="test-model")
callback1 = CustomCallback()
callback2 = CustomCallback()
litellm.success_callback.append(callback1)
litellm.success_callback.append(callback2)
new_callback = CustomCallback()
litellm.success_callback.remove(callback1)
llm.set_callbacks([new_callback])
assert litellm.callbacks == [new_callback]
assert len([cb for cb in litellm.success_callback if isinstance(cb, CustomCallback)]) == 0
@pytest.mark.parametrize("callback_count", [1, 3, 5])
def test_set_callbacks_with_different_sizes(callback_count, reset_litellm_callbacks):
"""Test with various numbers of callbacks."""
litellm.success_callback = []
litellm._async_success_callback = []
llm = LLM(model="test-model")
callbacks = [CustomCallback() for _ in range(callback_count)]
for callback in callbacks:
litellm.success_callback.append(callback)
new_callback = CustomCallback()
llm.set_callbacks([new_callback])
assert litellm.callbacks == [new_callback]
assert len([cb for cb in litellm.success_callback if isinstance(cb, CustomCallback)]) == 0
def test_set_callbacks_with_different_types(reset_litellm_callbacks):
"""Test that callbacks of different types are handled correctly."""
litellm.success_callback = []
litellm._async_success_callback = []
llm = LLM(model="test-model")
custom_callback = CustomCallback()
different_callback = DifferentCallback()
litellm.success_callback.append(custom_callback)
litellm.success_callback.append(different_callback)
llm.set_callbacks([CustomCallback()])
assert any(isinstance(cb, DifferentCallback) for cb in litellm.success_callback)
assert not any(isinstance(cb, CustomCallback) for cb in litellm.success_callback)
def test_set_callbacks_with_empty_list(reset_litellm_callbacks):
"""Test setting callbacks with an empty list."""
litellm.success_callback = []
litellm._async_success_callback = []
llm = LLM(model="test-model")
custom_callback = CustomCallback()
litellm.success_callback.append(custom_callback)
llm.set_callbacks([])
assert litellm.callbacks == []
assert custom_callback in litellm.success_callback

View File

@@ -395,11 +395,9 @@ def test_tools_emits_error_events():
)
crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
with patch.object(LLM, 'supports_function_calling', return_value=True):
crew.kickoff()
crew.kickoff()
assert len(received_events) > 0
assert len(received_events) == 48
assert received_events[0].agent_key == agent.key
assert received_events[0].agent_role == agent.role
assert received_events[0].tool_name == "error_tool"