From 3fc1381e76f359bac23741bd13b86d3d005b4067 Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Tue, 14 Oct 2025 15:34:28 -0700 Subject: [PATCH 01/11] feat: enhance AnthropicCompletion class with additional client parameters and tool handling - Added support for client_params in the AnthropicCompletion class to allow for additional client configuration. - Refactored client initialization to use a dedicated method for retrieving client parameters. - Implemented a new method to handle tool use conversation flow, ensuring proper execution and response handling. - Introduced comprehensive test cases to validate the functionality of the AnthropicCompletion class, including tool use scenarios and parameter handling. --- .../llms/providers/anthropic/completion.py | 257 +++++-- .../tests/llms/anthropic/test_anthropic.py | 660 ++++++++++++++++++ 2 files changed, 878 insertions(+), 39 deletions(-) create mode 100644 lib/crewai/tests/llms/anthropic/test_anthropic.py diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index 691490dd2..a90f06573 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -40,6 +40,7 @@ class AnthropicCompletion(BaseLLM): top_p: float | None = None, stop_sequences: list[str] | None = None, stream: bool = False, + client_params: dict[str, Any] | None = None, **kwargs, ): """Initialize Anthropic chat completion client. @@ -55,19 +56,20 @@ class AnthropicCompletion(BaseLLM): top_p: Nucleus sampling parameter stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop) stream: Enable streaming responses + client_params: Additional parameters for the Anthropic client **kwargs: Additional parameters """ super().__init__( model=model, temperature=temperature, stop=stop_sequences or [], **kwargs ) - # Initialize Anthropic client - self.client = Anthropic( - api_key=api_key or os.getenv("ANTHROPIC_API_KEY"), - base_url=base_url, - timeout=timeout, - max_retries=max_retries, - ) + # Client params + self.client_params = client_params + self.base_url = base_url + self.timeout = timeout + self.max_retries = max_retries + + self.client = Anthropic(**self._get_client_params()) # Store completion parameters self.max_tokens = max_tokens @@ -79,6 +81,26 @@ class AnthropicCompletion(BaseLLM): self.is_claude_3 = "claude-3" in model.lower() self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use + def _get_client_params(self) -> dict[str, Any]: + """Get client parameters.""" + + if self.api_key is None: + self.api_key = os.getenv("ANTHROPIC_API_KEY") + if self.api_key is None: + raise ValueError("ANTHROPIC_API_KEY is required") + + client_params = { + "api_key": self.api_key, + "base_url": self.base_url, + "timeout": self.timeout, + "max_retries": self.max_retries, + } + + if self.client_params: + client_params.update(self.client_params) + + return client_params + def call( self, messages: str | list[dict[str, str]], @@ -102,6 +124,7 @@ class AnthropicCompletion(BaseLLM): Chat completion response or tool call result """ try: + print("we are calling", messages) # Emit call started event self._emit_call_started_event( messages=messages, @@ -121,6 +144,7 @@ class AnthropicCompletion(BaseLLM): completion_params = self._prepare_completion_params( formatted_messages, system_message, tools ) + print("completion_params", completion_params) # Handle streaming vs non-streaming if self.stream: @@ -183,12 +207,25 
@@ class AnthropicCompletion(BaseLLM): def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]: """Convert CrewAI tool format to Anthropic tool use format.""" - from crewai.llms.providers.utils.common import safe_tool_conversion - anthropic_tools = [] for tool in tools: - name, description, parameters = safe_tool_conversion(tool, "Anthropic") + if "input_schema" in tool and "name" in tool and "description" in tool: + anthropic_tools.append(tool) + continue + + try: + from crewai.llms.providers.utils.common import safe_tool_conversion + + name, description, parameters = safe_tool_conversion(tool, "Anthropic") + except (ImportError, Exception): + name = tool.get("name", "unknown_tool") + description = tool.get("description", "A tool function") + parameters = ( + tool.get("input_schema") + or tool.get("parameters") + or tool.get("schema") + ) anthropic_tool = { "name": name, @@ -196,7 +233,13 @@ class AnthropicCompletion(BaseLLM): } if parameters and isinstance(parameters, dict): - anthropic_tool["input_schema"] = parameters # type: ignore + anthropic_tool["input_schema"] = parameters + else: + anthropic_tool["input_schema"] = { + "type": "object", + "properties": {}, + "required": [], + } anthropic_tools.append(anthropic_tool) @@ -229,13 +272,11 @@ class AnthropicCompletion(BaseLLM): content = message.get("content", "") if role == "system": - # Extract system message - Anthropic handles it separately if system_message: system_message += f"\n\n{content}" else: system_message = content else: - # Add user/assistant messages - ensure both role and content are str, not None role_str = role if role is not None else "user" content_str = content if content is not None else "" formatted_messages.append({"role": role_str, "content": content_str}) @@ -259,6 +300,7 @@ class AnthropicCompletion(BaseLLM): ) -> str | Any: """Handle non-streaming message completion.""" try: + print("params", params) response: Message = self.client.messages.create(**params) except Exception as e: @@ -270,22 +312,22 @@ class AnthropicCompletion(BaseLLM): usage = self._extract_anthropic_token_usage(response) self._track_token_usage_internal(usage) + # Check if Claude wants to use tools if response.content and available_functions: - for content_block in response.content: - if isinstance(content_block, ToolUseBlock): - function_name = content_block.name - function_args = content_block.input + tool_uses = [ + block for block in response.content if isinstance(block, ToolUseBlock) + ] - result = self._handle_tool_execution( - function_name=function_name, - function_args=function_args, # type: ignore - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, - ) - - if result is not None: - return result + if tool_uses: + # Handle tool use conversation flow + return self._handle_tool_use_conversation( + response, + tool_uses, + params, + available_functions, + from_task, + from_agent, + ) # Extract text content content = "" @@ -350,26 +392,54 @@ class AnthropicCompletion(BaseLLM): # Handle completed tool uses if tool_uses and available_functions: - for tool_data in tool_uses.values(): - function_name = tool_data["name"] - + # Convert streamed tool uses to ToolUseBlock-like objects for consistency + tool_use_blocks = [] + for tool_id, tool_data in tool_uses.items(): try: function_args = json.loads(tool_data["input"]) except json.JSONDecodeError as e: logging.error(f"Failed to parse streamed tool arguments: {e}") continue - # Execute tool - result = self._handle_tool_execution( - 
function_name=function_name, - function_args=function_args, - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, + # Create a mock ToolUseBlock-like object + class MockToolUse: + def __init__(self, tool_id: str, name: str, input_args: dict): + self.id = tool_id + self.name = name + self.input = input_args + + tool_use_blocks.append( + MockToolUse(tool_id, tool_data["name"], function_args) ) - if result is not None: - return result + if tool_use_blocks: + # Create a mock response object for the tool conversation flow + class MockResponse: + def __init__(self, content_blocks): + self.content = content_blocks + + # Combine text content and tool uses in the response + response_content = [] + if full_response.strip(): # Add text content if any + + class MockTextBlock: + def __init__(self, text: str): + self.text = text + + response_content.append(MockTextBlock(full_response)) + + response_content.extend(tool_use_blocks) + mock_response = MockResponse(response_content) + + # Handle tool use conversation flow + return self._handle_tool_use_conversation( + mock_response, + tool_use_blocks, + params, + available_functions, + from_task, + from_agent, + ) # Apply stop words to full response full_response = self._apply_stop_words(full_response) @@ -385,6 +455,115 @@ class AnthropicCompletion(BaseLLM): return full_response + def _handle_tool_use_conversation( + self, + initial_response: Message + | Any, # Can be Message or mock response from streaming + tool_uses: list[ToolUseBlock] + | list[Any], # Can be ToolUseBlock or mock objects + params: dict[str, Any], + available_functions: dict[str, Any], + from_task: Any | None = None, + from_agent: Any | None = None, + ) -> str: + """Handle the complete tool use conversation flow. + + This implements the proper Anthropic tool use pattern: + 1. Claude requests tool use + 2. We execute the tools + 3. We send tool results back to Claude + 4. 
Claude processes results and generates final response + """ + # Execute all requested tools and collect results + tool_results = [] + + for tool_use in tool_uses: + function_name = tool_use.name + function_args = tool_use.input + + # Execute the tool + result = self._handle_tool_execution( + function_name=function_name, + function_args=function_args, # type: ignore + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + # Create tool result in Anthropic format + tool_result = { + "type": "tool_result", + "tool_use_id": tool_use.id, + "content": str(result) + if result is not None + else "Tool execution completed", + } + tool_results.append(tool_result) + + # Prepare follow-up conversation with tool results + follow_up_params = params.copy() + + # Add Claude's tool use response to conversation + assistant_message = {"role": "assistant", "content": initial_response.content} + + # Add user message with tool results + user_message = {"role": "user", "content": tool_results} + + # Update messages for follow-up call + follow_up_params["messages"] = params["messages"] + [ + assistant_message, + user_message, + ] + + try: + # Send tool results back to Claude for final response + final_response: Message = self.client.messages.create(**follow_up_params) + + # Track token usage for follow-up call + follow_up_usage = self._extract_anthropic_token_usage(final_response) + self._track_token_usage_internal(follow_up_usage) + + # Extract final text content + final_content = "" + if final_response.content: + for content_block in final_response.content: + if hasattr(content_block, "text"): + final_content += content_block.text + + final_content = self._apply_stop_words(final_content) + + # Emit completion event for the final response + self._emit_call_completed_event( + response=final_content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=follow_up_params["messages"], + ) + + # Log combined token usage + total_usage = { + "input_tokens": follow_up_usage.get("input_tokens", 0), + "output_tokens": follow_up_usage.get("output_tokens", 0), + "total_tokens": follow_up_usage.get("total_tokens", 0), + } + + if total_usage.get("total_tokens", 0) > 0: + logging.info(f"Anthropic API tool conversation usage: {total_usage}") + + return final_content + + except Exception as e: + if is_context_length_exceeded(e): + logging.error(f"Context window exceeded in tool follow-up: {e}") + raise LLMContextLengthExceededError(str(e)) from e + + logging.error(f"Tool follow-up conversation failed: {e}") + # Fallback: return the first tool result if follow-up fails + if tool_results: + return tool_results[0]["content"] + raise e + def supports_function_calling(self) -> bool: """Check if the model supports function calling.""" return self.supports_tools diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py new file mode 100644 index 000000000..7d0780561 --- /dev/null +++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py @@ -0,0 +1,660 @@ +import os +import sys +import types +from unittest.mock import patch, MagicMock +import pytest + +from crewai.llm import LLM +from crewai.llms.providers.anthropic.completion import AnthropicCompletion +from crewai.crew import Crew +from crewai.agent import Agent +from crewai.task import Task +from crewai.cli.constants import DEFAULT_LLM_MODEL + + +def test_anthropic_completion_is_used_when_anthropic_provider(): + """ + Test that AnthropicCompletion 
from completion.py is used when LLM uses provider 'anthropic' + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + assert llm.__class__.__name__ == "AnthropicCompletion" + assert llm.provider == "anthropic" + assert llm.model == "claude-3-5-sonnet-20241022" + + +def test_anthropic_completion_is_used_when_claude_provider(): + """ + Test that AnthropicCompletion is used when provider is 'claude' + """ + llm = LLM(model="claude/claude-3-5-sonnet-20241022") + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + assert llm.provider == "claude" + assert llm.model == "claude-3-5-sonnet-20241022" + + + + +def test_anthropic_tool_use_conversation_flow(): + """ + Test that the Anthropic completion properly handles tool use conversation flow + """ + from unittest.mock import Mock, patch + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + from anthropic.types.tool_use_block import ToolUseBlock + + # Create AnthropicCompletion instance + completion = AnthropicCompletion(model="claude-3-5-sonnet-20241022") + + # Mock tool function + def mock_weather_tool(location: str) -> str: + return f"The weather in {location} is sunny and 75°F" + + available_functions = {"get_weather": mock_weather_tool} + + # Mock the Anthropic client responses + with patch.object(completion.client.messages, 'create') as mock_create: + # Mock initial response with tool use - need to properly mock ToolUseBlock + mock_tool_use = Mock(spec=ToolUseBlock) + mock_tool_use.id = "tool_123" + mock_tool_use.name = "get_weather" + mock_tool_use.input = {"location": "San Francisco"} + + mock_initial_response = Mock() + mock_initial_response.content = [mock_tool_use] + mock_initial_response.usage = Mock() + mock_initial_response.usage.input_tokens = 100 + mock_initial_response.usage.output_tokens = 50 + + # Mock final response after tool result - properly mock text content + mock_text_block = Mock() + # Set the text attribute as a string, not another Mock + mock_text_block.configure_mock(text="Based on the weather data, it's a beautiful day in San Francisco with sunny skies and 75°F temperature.") + + mock_final_response = Mock() + mock_final_response.content = [mock_text_block] + mock_final_response.usage = Mock() + mock_final_response.usage.input_tokens = 150 + mock_final_response.usage.output_tokens = 75 + + # Configure mock to return different responses on successive calls + mock_create.side_effect = [mock_initial_response, mock_final_response] + + # Test the call + messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}] + result = completion.call( + messages=messages, + available_functions=available_functions + ) + + # Verify the result contains the final response + assert "beautiful day in San Francisco" in result + assert "sunny skies" in result + assert "75°F" in result + + # Verify that two API calls were made (initial + follow-up) + assert mock_create.call_count == 2 + + # Verify the second call includes tool results + second_call_args = mock_create.call_args_list[1][1] # kwargs of second call + messages_in_second_call = second_call_args["messages"] + + # Should have original user message + assistant tool use + user tool result + assert len(messages_in_second_call) == 3 + assert messages_in_second_call[0]["role"] == "user" + assert messages_in_second_call[1]["role"] == "assistant" + assert messages_in_second_call[2]["role"] == "user" + + # Verify tool result format + tool_result = 
messages_in_second_call[2]["content"][0] + assert tool_result["type"] == "tool_result" + assert tool_result["tool_use_id"] == "tool_123" + assert "sunny and 75°F" in tool_result["content"] + + +def test_anthropic_completion_module_is_imported(): + """ + Test that the completion module is properly imported when using Anthropic provider + """ + module_name = "crewai.llms.providers.anthropic.completion" + + # Remove module from cache if it exists + if module_name in sys.modules: + del sys.modules[module_name] + + # Create LLM instance - this should trigger the import + LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Verify the module was imported + assert module_name in sys.modules + completion_mod = sys.modules[module_name] + assert isinstance(completion_mod, types.ModuleType) + + # Verify the class exists in the module + assert hasattr(completion_mod, 'AnthropicCompletion') + + +def test_fallback_to_litellm_when_native_anthropic_fails(): + """ + Test that LLM falls back to LiteLLM when native Anthropic completion fails + """ + # Mock the _get_native_provider to return a failing class + with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider: + + class FailingCompletion: + def __init__(self, *args, **kwargs): + raise Exception("Native Anthropic SDK failed") + + mock_get_provider.return_value = FailingCompletion + + # This should fall back to LiteLLM + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Check that it's using LiteLLM + assert hasattr(llm, 'is_litellm') + assert llm.is_litellm == True + + +def test_anthropic_completion_initialization_parameters(): + """ + Test that AnthropicCompletion is initialized with correct parameters + """ + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + temperature=0.7, + max_tokens=2000, + top_p=0.9, + api_key="test-key" + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + assert llm.model == "claude-3-5-sonnet-20241022" + assert llm.temperature == 0.7 + assert llm.max_tokens == 2000 + assert llm.top_p == 0.9 + + +def test_anthropic_specific_parameters(): + """ + Test Anthropic-specific parameters like stop_sequences and streaming + """ + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + stop_sequences=["Human:", "Assistant:"], + stream=True, + max_retries=5, + timeout=60 + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + assert llm.stop_sequences == ["Human:", "Assistant:"] + assert llm.stream == True + assert llm.client.max_retries == 5 + assert llm.client.timeout == 60 + + +def test_anthropic_completion_call(): + """ + Test that AnthropicCompletion call method works + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the call method on the instance + with patch.object(llm, 'call', return_value="Hello! I'm Claude, ready to help.") as mock_call: + result = llm.call("Hello, how are you?") + + assert result == "Hello! I'm Claude, ready to help." 
+ mock_call.assert_called_once_with("Hello, how are you?") + + +def test_anthropic_completion_called_during_crew_execution(): + """ + Test that AnthropicCompletion.call is actually invoked when running a crew + """ + # Create the LLM instance first + anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the call method on the specific instance + with patch.object(anthropic_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call: + + # Create agent with explicit LLM configuration + agent = Agent( + role="Research Assistant", + goal="Find population info", + backstory="You research populations.", + llm=anthropic_llm, + ) + + task = Task( + description="Find Tokyo population", + expected_output="Population number", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + result = crew.kickoff() + + # Verify mock was called + assert mock_call.called + assert "14 million" in str(result) + + +def test_anthropic_completion_call_arguments(): + """ + Test that AnthropicCompletion.call is invoked with correct arguments + """ + # Create LLM instance first + anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the instance method + with patch.object(anthropic_llm, 'call') as mock_call: + mock_call.return_value = "Task completed successfully." + + agent = Agent( + role="Test Agent", + goal="Complete a simple task", + backstory="You are a test agent.", + llm=anthropic_llm # Use same instance + ) + + task = Task( + description="Say hello world", + expected_output="Hello world", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() + + # Verify call was made + assert mock_call.called + + # Check the arguments passed to the call method + call_args = mock_call.call_args + assert call_args is not None + + # The first argument should be the messages + messages = call_args[0][0] # First positional argument + assert isinstance(messages, (str, list)) + + # Verify that the task description appears in the messages + if isinstance(messages, str): + assert "hello world" in messages.lower() + elif isinstance(messages, list): + message_content = str(messages).lower() + assert "hello world" in message_content + + +def test_multiple_anthropic_calls_in_crew(): + """ + Test that AnthropicCompletion.call is invoked multiple times for multiple tasks + """ + # Create LLM instance first + anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the instance method + with patch.object(anthropic_llm, 'call') as mock_call: + mock_call.return_value = "Task completed." 
+ + agent = Agent( + role="Multi-task Agent", + goal="Complete multiple tasks", + backstory="You can handle multiple tasks.", + llm=anthropic_llm # Use same instance + ) + + task1 = Task( + description="First task", + expected_output="First result", + agent=agent, + ) + + task2 = Task( + description="Second task", + expected_output="Second result", + agent=agent, + ) + + crew = Crew( + agents=[agent], + tasks=[task1, task2] + ) + crew.kickoff() + + # Verify multiple calls were made + assert mock_call.call_count >= 2 # At least one call per task + + # Verify each call had proper arguments + for call in mock_call.call_args_list: + assert len(call[0]) > 0 # Has positional arguments + messages = call[0][0] + assert messages is not None + + +def test_anthropic_completion_with_tools(): + """ + Test that AnthropicCompletion.call is invoked with tools when agent has tools + """ + from crewai.tools import tool + + @tool + def sample_tool(query: str) -> str: + """A sample tool for testing""" + return f"Tool result for: {query}" + + # Create LLM instance first + anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the instance method + with patch.object(anthropic_llm, 'call') as mock_call: + mock_call.return_value = "Task completed with tools." + + agent = Agent( + role="Tool User", + goal="Use tools to complete tasks", + backstory="You can use tools.", + llm=anthropic_llm, # Use same instance + tools=[sample_tool] + ) + + task = Task( + description="Use the sample tool", + expected_output="Tool usage result", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() + + assert mock_call.called + + call_args = mock_call.call_args + call_kwargs = call_args[1] if len(call_args) > 1 else {} + + if 'tools' in call_kwargs: + assert call_kwargs['tools'] is not None + assert len(call_kwargs['tools']) > 0 + + +def test_anthropic_raises_error_when_model_not_supported(): + """Test that AnthropicCompletion raises ValueError when model not supported""" + + # Mock the Anthropic client to raise an error + with patch('crewai.llms.providers.anthropic.completion.Anthropic') as mock_anthropic_class: + mock_client = MagicMock() + mock_anthropic_class.return_value = mock_client + + # Mock the error that Anthropic would raise for unsupported models + from anthropic import NotFoundError + mock_client.messages.create.side_effect = NotFoundError( + message="The model `model-doesnt-exist` does not exist", + response=MagicMock(), + body={} + ) + + llm = LLM(model="anthropic/model-doesnt-exist") + + with pytest.raises(Exception): # Should raise some error for unsupported model + llm.call("Hello") + + +def test_anthropic_client_params_setup(): + """ + Test that client_params are properly merged with default client parameters + """ + # Use only valid Anthropic client parameters + custom_client_params = { + "default_headers": {"X-Custom-Header": "test-value"}, + } + + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + api_key="test-key", + base_url="https://custom-api.com", + timeout=45, + max_retries=5, + client_params=custom_client_params + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + + assert llm.client_params == custom_client_params + + merged_params = llm._get_client_params() + + assert merged_params["api_key"] == "test-key" + assert merged_params["base_url"] == "https://custom-api.com" + assert merged_params["timeout"] 
== 45 + assert merged_params["max_retries"] == 5 + + assert merged_params["default_headers"] == {"X-Custom-Header": "test-value"} + + +def test_anthropic_client_params_override_defaults(): + """ + Test that client_params can override default client parameters + """ + override_client_params = { + "timeout": 120, # Override the timeout parameter + "max_retries": 10, # Override the max_retries parameter + "default_headers": {"X-Override": "true"} # Valid custom parameter + } + + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + api_key="test-key", + timeout=30, + max_retries=3, + client_params=override_client_params + ) + + # Verify this is actually AnthropicCompletion, not LiteLLM fallback + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + + merged_params = llm._get_client_params() + + # client_params should override the individual parameters + assert merged_params["timeout"] == 120 + assert merged_params["max_retries"] == 10 + assert merged_params["default_headers"] == {"X-Override": "true"} + + +def test_anthropic_client_params_none(): + """ + Test that client_params=None works correctly (no additional parameters) + """ + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + api_key="test-key", + base_url="https://api.anthropic.com", + timeout=60, + max_retries=2, + client_params=None + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + + assert llm.client_params is None + + merged_params = llm._get_client_params() + + expected_keys = {"api_key", "base_url", "timeout", "max_retries"} + assert set(merged_params.keys()) == expected_keys + + # Fixed assertions - all should be inside the with block and use correct values + assert merged_params["api_key"] == "test-key" # Not "test-anthropic-key" + assert merged_params["base_url"] == "https://api.anthropic.com" + assert merged_params["timeout"] == 60 + assert merged_params["max_retries"] == 2 + + +def test_anthropic_client_params_empty_dict(): + """ + Test that client_params={} works correctly (empty additional parameters) + """ + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + api_key="test-key", + client_params={} + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + + assert llm.client_params == {} + + merged_params = llm._get_client_params() + + assert "api_key" in merged_params + assert merged_params["api_key"] == "test-key" + + +def test_anthropic_model_detection(): + """ + Test that various Anthropic model formats are properly detected + """ + # Test Anthropic model naming patterns that actually work with provider detection + anthropic_test_cases = [ + "anthropic/claude-3-5-sonnet-20241022", + "claude/claude-3-5-sonnet-20241022" + ] + + for model_name in anthropic_test_cases: + llm = LLM(model=model_name) + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion), f"Failed for model: {model_name}" + + +def test_anthropic_supports_stop_words(): + """ + Test that Anthropic models support stop sequences + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + assert llm.supports_stop_words() == True + + +def 
test_anthropic_context_window_size(): + """ + Test that Anthropic models return correct context window sizes + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + context_size = llm.get_context_window_size() + + # Should return a reasonable context window size (Claude 3.5 has 200k tokens) + assert context_size > 100000 # Should be substantial + assert context_size <= 200000 # But not exceed the actual limit + + +def test_anthropic_message_formatting(): + """ + Test that messages are properly formatted for Anthropic API + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Test message formatting + test_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + + formatted_messages, system_message = llm._format_messages_for_anthropic(test_messages) + + # System message should be extracted + assert system_message == "You are a helpful assistant." + + # Remaining messages should start with user + assert formatted_messages[0]["role"] == "user" + assert len(formatted_messages) >= 3 # Should have user, assistant, user messages + + +def test_anthropic_streaming_parameter(): + """ + Test that streaming parameter is properly handled + """ + # Test non-streaming + llm_no_stream = LLM(model="anthropic/claude-3-5-sonnet-20241022", stream=False) + assert llm_no_stream.stream == False + + # Test streaming + llm_stream = LLM(model="anthropic/claude-3-5-sonnet-20241022", stream=True) + assert llm_stream.stream == True + + +def test_anthropic_tool_conversion(): + """ + Test that tools are properly converted to Anthropic format + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock tool in CrewAI format + crewai_tools = [{ + "type": "function", + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"] + } + } + }] + + # Test tool conversion + anthropic_tools = llm._convert_tools_for_interference(crewai_tools) + + assert len(anthropic_tools) == 1 + assert anthropic_tools[0]["name"] == "test_tool" + assert anthropic_tools[0]["description"] == "A test tool" + assert "input_schema" in anthropic_tools[0] + + +def test_anthropic_environment_variable_api_key(): + """ + Test that Anthropic API key is properly loaded from environment + """ + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}): + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + assert llm.client is not None + assert hasattr(llm.client, 'messages') + + +def test_anthropic_token_usage_tracking(): + """ + Test that token usage is properly tracked for Anthropic responses + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the Anthropic response with usage information + with patch.object(llm.client.messages, 'create') as mock_create: + mock_response = MagicMock() + mock_response.content = [MagicMock(text="test response")] + mock_response.usage = MagicMock(input_tokens=50, output_tokens=25) + mock_create.return_value = mock_response + + result = llm.call("Hello") + + # Verify the response + assert result == "test response" + + # Verify token usage was extracted + usage = llm._extract_anthropic_token_usage(mock_response) + assert usage["input_tokens"] == 50 + assert usage["output_tokens"] == 25 + assert usage["total_tokens"] == 75 From 
7045ed389ae5bd9b4a3fc29002d4bc238772c88a Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Tue, 14 Oct 2025 15:36:30 -0700 Subject: [PATCH 02/11] drop print statements --- lib/crewai/src/crewai/llms/providers/anthropic/completion.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index a90f06573..ffcaf3077 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -124,7 +124,6 @@ class AnthropicCompletion(BaseLLM): Chat completion response or tool call result """ try: - print("we are calling", messages) # Emit call started event self._emit_call_started_event( messages=messages, @@ -144,7 +143,6 @@ class AnthropicCompletion(BaseLLM): completion_params = self._prepare_completion_params( formatted_messages, system_message, tools ) - print("completion_params", completion_params) # Handle streaming vs non-streaming if self.stream: @@ -300,7 +298,6 @@ class AnthropicCompletion(BaseLLM): ) -> str | Any: """Handle non-streaming message completion.""" try: - print("params", params) response: Message = self.client.messages.create(**params) except Exception as e: From 97c2cbd11069d4cef4db37f7bd1d70e967af39e1 Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Wed, 15 Oct 2025 11:12:35 -0700 Subject: [PATCH 03/11] test: add fixture to mock ANTHROPIC_API_KEY for tests - Introduced a pytest fixture to automatically mock the ANTHROPIC_API_KEY environment variable for all tests in the test_anthropic.py module. - This change ensures that tests can run without requiring a real API key, improving test isolation and reliability. --- lib/crewai/tests/llms/anthropic/test_anthropic.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py index 7d0780561..90a0eb766 100644 --- a/lib/crewai/tests/llms/anthropic/test_anthropic.py +++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py @@ -5,11 +5,16 @@ from unittest.mock import patch, MagicMock import pytest from crewai.llm import LLM -from crewai.llms.providers.anthropic.completion import AnthropicCompletion from crewai.crew import Crew from crewai.agent import Agent from crewai.task import Task -from crewai.cli.constants import DEFAULT_LLM_MODEL + + +@pytest.fixture(autouse=True) +def mock_anthropic_api_key(): + """Automatically mock ANTHROPIC_API_KEY for all tests in this module.""" + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + yield def test_anthropic_completion_is_used_when_anthropic_provider(): From 38e7a37485590ff576a3e4d0da72f0b9ef78e2c2 Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Wed, 15 Oct 2025 18:57:27 -0700 Subject: [PATCH 04/11] feat: enhance GeminiCompletion class with additional client parameters and refactor client initialization - Added support for client_params in the GeminiCompletion class to allow for additional client configuration. - Refactored client initialization to use a dedicated method for retrieving client parameters, improving code organization and clarity. - Introduced comprehensive test cases to validate the functionality of the GeminiCompletion class, ensuring proper handling of tool use and parameter management. 
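A minimal usage sketch (hypothetical values; per the updated docstring, client_params
is forwarded to the google-genai Client constructor and supports options such as
http_options):

    from crewai.llm import LLM

    # assumes GOOGLE_API_KEY or GEMINI_API_KEY is set in the environment
    llm = LLM(
        model="gemini/gemini-2.0-flash-001",
        client_params={"http_options": {"api_version": "v1"}},
    )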
--- .../llms/providers/gemini/completion.py | 107 ++- lib/crewai/tests/llms/google/test_google.py | 644 ++++++++++++++++++ 2 files changed, 730 insertions(+), 21 deletions(-) create mode 100644 lib/crewai/tests/llms/google/test_google.py diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index 7012e5ca0..987a55b49 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -11,9 +11,9 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( try: - from google import genai # type: ignore - from google.genai import types # type: ignore - from google.genai.errors import APIError # type: ignore + from google import genai + from google.genai import types + from google.genai.errors import APIError except ImportError: raise ImportError( "Google Gen AI native provider not available, to install: `uv add google-genai`" @@ -40,6 +40,7 @@ class GeminiCompletion(BaseLLM): stop_sequences: list[str] | None = None, stream: bool = False, safety_settings: dict[str, Any] | None = None, + client_params: dict[str, Any] | None = None, **kwargs, ): """Initialize Google Gemini chat completion client. @@ -56,35 +57,27 @@ class GeminiCompletion(BaseLLM): stop_sequences: Stop sequences stream: Enable streaming responses safety_settings: Safety filter settings + client_params: Additional parameters to pass to the Google Gen AI Client constructor. + Supports parameters like http_options, credentials, debug_config, etc. **kwargs: Additional parameters """ super().__init__( model=model, temperature=temperature, stop=stop_sequences or [], **kwargs ) - # Get API configuration + # Store client params for later use + self.client_params = client_params or {} + + # Get API configuration with environment variable fallbacks self.api_key = ( api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") ) self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT") self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1" - # Initialize client based on available configuration - if self.project: - # Use Vertex AI - self.client = genai.Client( - vertexai=True, - project=self.project, - location=self.location, - ) - elif self.api_key: - # Use Gemini Developer API - self.client = genai.Client(api_key=self.api_key) - else: - raise ValueError( - "Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or " - "GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set" - ) + use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" + + self.client = self._initialize_client(use_vertexai) # Store completion parameters self.top_p = top_p @@ -99,6 +92,78 @@ class GeminiCompletion(BaseLLM): self.is_gemini_1_5 = "gemini-1.5" in model.lower() self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2 + def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: + """Initialize the Google Gen AI client with proper parameter handling. 
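+
+        Explicitly supplied configuration (api_key, vertexai, project, location)
+        takes precedence over the same keys in client_params; keys that do not
+        apply to the selected mode are removed before the Client is constructed.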
+ + Args: + use_vertexai: Whether to use Vertex AI (from environment variable) + + Returns: + Initialized Google Gen AI Client + """ + client_params = {} + + if self.client_params: + client_params.update(self.client_params) + + if use_vertexai or self.project: + client_params.update( + { + "vertexai": True, + "project": self.project, + "location": self.location, + } + ) + + client_params.pop("api_key", None) + + elif self.api_key: + client_params["api_key"] = self.api_key + + client_params.pop("vertexai", None) + client_params.pop("project", None) + client_params.pop("location", None) + + else: + try: + return genai.Client(**client_params) + except Exception as e: + raise ValueError( + "Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or " + "GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set" + ) from e + + return genai.Client(**client_params) + + def _get_client_params(self) -> dict[str, Any]: + """Get client parameters for compatibility with base class. + + Note: This method is kept for compatibility but the Google Gen AI SDK + uses a different initialization pattern via the Client constructor. + """ + params = {} + + if ( + hasattr(self, "client") + and hasattr(self.client, "vertexai") + and self.client.vertexai + ): + # Vertex AI configuration + params.update( + { + "vertexai": True, + "project": self.project, + "location": self.location, + } + ) + elif self.api_key: + params["api_key"] = self.api_key + + if self.client_params: + params.update(self.client_params) + + return params + def call( self, messages: str | list[dict[str, str]], @@ -427,7 +492,7 @@ class GeminiCompletion(BaseLLM): def supports_stop_words(self) -> bool: """Check if the model supports stop words.""" - return self._supports_stop_words_implementation() + return True def get_context_window_size(self) -> int: """Get the context window size for the model.""" diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py new file mode 100644 index 000000000..ce6219a0c --- /dev/null +++ b/lib/crewai/tests/llms/google/test_google.py @@ -0,0 +1,644 @@ +import os +import sys +import types +from unittest.mock import patch, MagicMock +import pytest + +from crewai.llm import LLM +from crewai.crew import Crew +from crewai.agent import Agent +from crewai.task import Task + + +@pytest.fixture(autouse=True) +def mock_anthropic_api_key(): + """Automatically mock ANTHROPIC_API_KEY for all tests in this module.""" + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + yield + + +def test_gemini_completion_is_used_when_google_provider(): + """ + Test that GeminiCompletion from completion.py is used when LLM uses provider 'google' + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + assert llm.__class__.__name__ == "GeminiCompletion" + assert llm.provider == "google" + assert llm.model == "gemini-2.0-flash-001" + + +def test_gemini_completion_is_used_when_gemini_provider(): + """ + Test that GeminiCompletion is used when provider is 'gemini' + """ + llm = LLM(model="gemini/gemini-2.0-flash-001") + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + assert llm.provider == "gemini" + assert llm.model == "gemini-2.0-flash-001" + + + + +def test_gemini_tool_use_conversation_flow(): + """ + Test that the Gemini completion properly handles tool use conversation flow + """ + from unittest.mock import Mock, patch + from crewai.llms.providers.gemini.completion import GeminiCompletion + + # Create GeminiCompletion 
instance + completion = GeminiCompletion(model="gemini-2.0-flash-001") + + # Mock tool function + def mock_weather_tool(location: str) -> str: + return f"The weather in {location} is sunny and 75°F" + + available_functions = {"get_weather": mock_weather_tool} + + # Mock the Google Gemini client responses + with patch.object(completion.client.models, 'generate_content') as mock_generate: + # Mock function call in response + mock_function_call = Mock() + mock_function_call.name = "get_weather" + mock_function_call.args = {"location": "San Francisco"} + + mock_part = Mock() + mock_part.function_call = mock_function_call + + mock_content = Mock() + mock_content.parts = [mock_part] + + mock_candidate = Mock() + mock_candidate.content = mock_content + + mock_response = Mock() + mock_response.candidates = [mock_candidate] + mock_response.text = "Based on the weather data, it's a beautiful day in San Francisco with sunny skies and 75°F temperature." + mock_response.usage_metadata = Mock() + mock_response.usage_metadata.prompt_token_count = 100 + mock_response.usage_metadata.candidates_token_count = 50 + mock_response.usage_metadata.total_token_count = 150 + + mock_generate.return_value = mock_response + + # Test the call + messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}] + result = completion.call( + messages=messages, + available_functions=available_functions + ) + + # Verify the tool was executed and returned the result + assert result == "The weather in San Francisco is sunny and 75°F" + + # Verify that the API was called + assert mock_generate.called + + +def test_gemini_completion_module_is_imported(): + """ + Test that the completion module is properly imported when using Google provider + """ + module_name = "crewai.llms.providers.gemini.completion" + + # Remove module from cache if it exists + if module_name in sys.modules: + del sys.modules[module_name] + + # Create LLM instance - this should trigger the import + LLM(model="google/gemini-2.0-flash-001") + + # Verify the module was imported + assert module_name in sys.modules + completion_mod = sys.modules[module_name] + assert isinstance(completion_mod, types.ModuleType) + + # Verify the class exists in the module + assert hasattr(completion_mod, 'GeminiCompletion') + + +def test_fallback_to_litellm_when_native_gemini_fails(): + """ + Test that LLM falls back to LiteLLM when native Gemini completion fails + """ + # Mock the _get_native_provider to return a failing class + with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider: + + class FailingCompletion: + def __init__(self, *args, **kwargs): + raise Exception("Native Google Gen AI SDK failed") + + mock_get_provider.return_value = FailingCompletion + + # This should fall back to LiteLLM + llm = LLM(model="google/gemini-2.0-flash-001") + + # Check that it's using LiteLLM + assert hasattr(llm, 'is_litellm') + assert llm.is_litellm == True + + +def test_gemini_completion_initialization_parameters(): + """ + Test that GeminiCompletion is initialized with correct parameters + """ + llm = LLM( + model="google/gemini-2.0-flash-001", + temperature=0.7, + max_output_tokens=2000, + top_p=0.9, + top_k=40, + api_key="test-key" + ) + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + assert llm.model == "gemini-2.0-flash-001" + assert llm.temperature == 0.7 + assert llm.max_output_tokens == 2000 + assert llm.top_p == 0.9 + assert llm.top_k == 40 + + +def 
test_gemini_specific_parameters(): + """ + Test Gemini-specific parameters like stop_sequences, streaming, and safety settings + """ + safety_settings = { + "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE" + } + + llm = LLM( + model="google/gemini-2.0-flash-001", + stop_sequences=["Human:", "Assistant:"], + stream=True, + safety_settings=safety_settings, + project="test-project", + location="us-central1" + ) + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + assert llm.stop_sequences == ["Human:", "Assistant:"] + assert llm.stream == True + assert llm.safety_settings == safety_settings + assert llm.project == "test-project" + assert llm.location == "us-central1" + + +def test_gemini_completion_call(): + """ + Test that GeminiCompletion call method works + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the call method on the instance + with patch.object(llm, 'call', return_value="Hello! I'm Gemini, ready to help.") as mock_call: + result = llm.call("Hello, how are you?") + + assert result == "Hello! I'm Gemini, ready to help." + mock_call.assert_called_once_with("Hello, how are you?") + + +def test_gemini_completion_called_during_crew_execution(): + """ + Test that GeminiCompletion.call is actually invoked when running a crew + """ + # Create the LLM instance first + gemini_llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the call method on the specific instance + with patch.object(gemini_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call: + + # Create agent with explicit LLM configuration + agent = Agent( + role="Research Assistant", + goal="Find population info", + backstory="You research populations.", + llm=gemini_llm, + ) + + task = Task( + description="Find Tokyo population", + expected_output="Population number", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + result = crew.kickoff() + + # Verify mock was called + assert mock_call.called + assert "14 million" in str(result) + + +def test_gemini_completion_call_arguments(): + """ + Test that GeminiCompletion.call is invoked with correct arguments + """ + # Create LLM instance first + gemini_llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the instance method + with patch.object(gemini_llm, 'call') as mock_call: + mock_call.return_value = "Task completed successfully." 
+ + agent = Agent( + role="Test Agent", + goal="Complete a simple task", + backstory="You are a test agent.", + llm=gemini_llm # Use same instance + ) + + task = Task( + description="Say hello world", + expected_output="Hello world", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() + + # Verify call was made + assert mock_call.called + + # Check the arguments passed to the call method + call_args = mock_call.call_args + assert call_args is not None + + # The first argument should be the messages + messages = call_args[0][0] # First positional argument + assert isinstance(messages, (str, list)) + + # Verify that the task description appears in the messages + if isinstance(messages, str): + assert "hello world" in messages.lower() + elif isinstance(messages, list): + message_content = str(messages).lower() + assert "hello world" in message_content + + +def test_multiple_gemini_calls_in_crew(): + """ + Test that GeminiCompletion.call is invoked multiple times for multiple tasks + """ + # Create LLM instance first + gemini_llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the instance method + with patch.object(gemini_llm, 'call') as mock_call: + mock_call.return_value = "Task completed." + + agent = Agent( + role="Multi-task Agent", + goal="Complete multiple tasks", + backstory="You can handle multiple tasks.", + llm=gemini_llm # Use same instance + ) + + task1 = Task( + description="First task", + expected_output="First result", + agent=agent, + ) + + task2 = Task( + description="Second task", + expected_output="Second result", + agent=agent, + ) + + crew = Crew( + agents=[agent], + tasks=[task1, task2] + ) + crew.kickoff() + + # Verify multiple calls were made + assert mock_call.call_count >= 2 # At least one call per task + + # Verify each call had proper arguments + for call in mock_call.call_args_list: + assert len(call[0]) > 0 # Has positional arguments + messages = call[0][0] + assert messages is not None + + +def test_gemini_completion_with_tools(): + """ + Test that GeminiCompletion.call is invoked with tools when agent has tools + """ + from crewai.tools import tool + + @tool + def sample_tool(query: str) -> str: + """A sample tool for testing""" + return f"Tool result for: {query}" + + # Create LLM instance first + gemini_llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the instance method + with patch.object(gemini_llm, 'call') as mock_call: + mock_call.return_value = "Task completed with tools." 
+ + agent = Agent( + role="Tool User", + goal="Use tools to complete tasks", + backstory="You can use tools.", + llm=gemini_llm, # Use same instance + tools=[sample_tool] + ) + + task = Task( + description="Use the sample tool", + expected_output="Tool usage result", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() + + assert mock_call.called + + call_args = mock_call.call_args + call_kwargs = call_args[1] if len(call_args) > 1 else {} + + if 'tools' in call_kwargs: + assert call_kwargs['tools'] is not None + assert len(call_kwargs['tools']) > 0 + + +def test_gemini_raises_error_when_model_not_supported(): + """Test that GeminiCompletion raises ValueError when model not supported""" + + # Mock the Google client to raise an error + with patch('crewai.llms.providers.gemini.completion.genai') as mock_genai: + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + + # Mock the error that Google would raise for unsupported models + from google.genai.errors import ClientError # type: ignore + mock_client.models.generate_content.side_effect = ClientError( + code=404, + response_json={ + 'error': { + 'code': 404, + 'message': 'models/model-doesnt-exist is not found for API version v1beta, or is not supported for generateContent.', + 'status': 'NOT_FOUND' + } + } + ) + + llm = LLM(model="google/model-doesnt-exist") + + with pytest.raises(Exception): # Should raise some error for unsupported model + llm.call("Hello") + + +def test_gemini_vertex_ai_setup(): + """ + Test that Vertex AI configuration is properly handled + """ + with patch.dict(os.environ, { + "GOOGLE_CLOUD_PROJECT": "test-project", + "GOOGLE_CLOUD_LOCATION": "us-west1" + }): + llm = LLM( + model="google/gemini-2.0-flash-001", + project="test-project", + location="us-west1" + ) + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + + assert llm.project == "test-project" + assert llm.location == "us-west1" + + +def test_gemini_api_key_configuration(): + """ + Test that API key configuration works for both GOOGLE_API_KEY and GEMINI_API_KEY + """ + # Test with GOOGLE_API_KEY + with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}): + llm = LLM(model="google/gemini-2.0-flash-001") + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + assert llm.api_key == "test-google-key" + + # Test with GEMINI_API_KEY + with patch.dict(os.environ, {"GEMINI_API_KEY": "test-gemini-key"}, clear=True): + llm = LLM(model="google/gemini-2.0-flash-001") + + assert isinstance(llm, GeminiCompletion) + assert llm.api_key == "test-gemini-key" + + +def test_gemini_model_capabilities(): + """ + Test that model capabilities are correctly identified + """ + # Test Gemini 2.0 model + llm_2_0 = LLM(model="google/gemini-2.0-flash-001") + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm_2_0, GeminiCompletion) + assert llm_2_0.is_gemini_2 == True + assert llm_2_0.supports_tools == True + + # Test Gemini 1.5 model + llm_1_5 = LLM(model="google/gemini-1.5-pro") + assert isinstance(llm_1_5, GeminiCompletion) + assert llm_1_5.is_gemini_1_5 == True + assert llm_1_5.supports_tools == True + + +def test_gemini_generation_config(): + """ + Test that generation config is properly prepared + """ + llm = LLM( + model="google/gemini-2.0-flash-001", + temperature=0.7, + top_p=0.9, + top_k=40, + max_output_tokens=1000 + ) + + from 
crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + + # Test config preparation + config = llm._prepare_generation_config() + + # Verify config has the expected parameters + assert hasattr(config, 'temperature') or 'temperature' in str(config) + assert hasattr(config, 'top_p') or 'top_p' in str(config) + assert hasattr(config, 'top_k') or 'top_k' in str(config) + assert hasattr(config, 'max_output_tokens') or 'max_output_tokens' in str(config) + + +def test_gemini_model_detection(): + """ + Test that various Gemini model formats are properly detected + """ + # Test Gemini model naming patterns that actually work with provider detection + gemini_test_cases = [ + "google/gemini-2.0-flash-001", + "gemini/gemini-2.0-flash-001", + "google/gemini-1.5-pro", + "gemini/gemini-1.5-flash" + ] + + for model_name in gemini_test_cases: + llm = LLM(model=model_name) + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion), f"Failed for model: {model_name}" + + +def test_gemini_supports_stop_words(): + """ + Test that Gemini models support stop sequences + """ + llm = LLM(model="google/gemini-2.0-flash-001") + assert llm.supports_stop_words() == True + + +def test_gemini_context_window_size(): + """ + Test that Gemini models return correct context window sizes + """ + # Test Gemini 2.0 Flash + llm_2_0 = LLM(model="google/gemini-2.0-flash-001") + context_size_2_0 = llm_2_0.get_context_window_size() + assert context_size_2_0 > 500000 # Should be substantial (1M tokens) + + # Test Gemini 1.5 Pro + llm_1_5 = LLM(model="google/gemini-1.5-pro") + context_size_1_5 = llm_1_5.get_context_window_size() + assert context_size_1_5 > 1000000 # Should be very large (2M tokens) + + +def test_gemini_message_formatting(): + """ + Test that messages are properly formatted for Gemini API + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + # Test message formatting + test_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + + formatted_contents, system_instruction = llm._format_messages_for_gemini(test_messages) + + # System message should be extracted + assert system_instruction == "You are a helpful assistant." 
+ + # Remaining messages should be Content objects + assert len(formatted_contents) >= 3 # Should have user, model, user messages + + # First content should be user role + assert formatted_contents[0].role == "user" + # Second should be model (converted from assistant) + assert formatted_contents[1].role == "model" + + +def test_gemini_streaming_parameter(): + """ + Test that streaming parameter is properly handled + """ + # Test non-streaming + llm_no_stream = LLM(model="google/gemini-2.0-flash-001", stream=False) + assert llm_no_stream.stream == False + + # Test streaming + llm_stream = LLM(model="google/gemini-2.0-flash-001", stream=True) + assert llm_stream.stream == True + + +def test_gemini_tool_conversion(): + """ + Test that tools are properly converted to Gemini format + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock tool in CrewAI format + crewai_tools = [{ + "type": "function", + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"] + } + } + }] + + # Test tool conversion + gemini_tools = llm._convert_tools_for_interference(crewai_tools) + + assert len(gemini_tools) == 1 + # Gemini tools are Tool objects with function_declarations + assert hasattr(gemini_tools[0], 'function_declarations') + assert len(gemini_tools[0].function_declarations) == 1 + + func_decl = gemini_tools[0].function_declarations[0] + assert func_decl.name == "test_tool" + assert func_decl.description == "A test tool" + + +def test_gemini_environment_variable_api_key(): + """ + Test that Google API key is properly loaded from environment + """ + with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}): + llm = LLM(model="google/gemini-2.0-flash-001") + + assert llm.client is not None + assert hasattr(llm.client, 'models') + assert llm.api_key == "test-google-key" + + +def test_gemini_token_usage_tracking(): + """ + Test that token usage is properly tracked for Gemini responses + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the Gemini response with usage information + with patch.object(llm.client.models, 'generate_content') as mock_generate: + mock_response = MagicMock() + mock_response.text = "test response" + mock_response.candidates = [] + mock_response.usage_metadata = MagicMock( + prompt_token_count=50, + candidates_token_count=25, + total_token_count=75 + ) + mock_generate.return_value = mock_response + + result = llm.call("Hello") + + # Verify the response + assert result == "test response" + + # Verify token usage was extracted + usage = llm._extract_token_usage(mock_response) + assert usage["prompt_token_count"] == 50 + assert usage["candidates_token_count"] == 25 + assert usage["total_token_count"] == 75 + assert usage["total_tokens"] == 75 From dcd57ccc9f7d8004a99af04600cec45300be6755 Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Tue, 14 Oct 2025 15:34:28 -0700 Subject: [PATCH 05/11] feat: enhance AnthropicCompletion class with additional client parameters and tool handling - Added support for client_params in the AnthropicCompletion class to allow for additional client configuration. - Refactored client initialization to use a dedicated method for retrieving client parameters. - Implemented a new method to handle tool use conversation flow, ensuring proper execution and response handling. 
- Introduced comprehensive test cases to validate the functionality of the AnthropicCompletion class, including tool use scenarios and parameter handling. --- .../llms/providers/anthropic/completion.py | 257 +++++-- .../tests/llms/anthropic/test_anthropic.py | 660 ++++++++++++++++++ 2 files changed, 878 insertions(+), 39 deletions(-) create mode 100644 lib/crewai/tests/llms/anthropic/test_anthropic.py diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index 691490dd2..a90f06573 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -40,6 +40,7 @@ class AnthropicCompletion(BaseLLM): top_p: float | None = None, stop_sequences: list[str] | None = None, stream: bool = False, + client_params: dict[str, Any] | None = None, **kwargs, ): """Initialize Anthropic chat completion client. @@ -55,19 +56,20 @@ class AnthropicCompletion(BaseLLM): top_p: Nucleus sampling parameter stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop) stream: Enable streaming responses + client_params: Additional parameters for the Anthropic client **kwargs: Additional parameters """ super().__init__( model=model, temperature=temperature, stop=stop_sequences or [], **kwargs ) - # Initialize Anthropic client - self.client = Anthropic( - api_key=api_key or os.getenv("ANTHROPIC_API_KEY"), - base_url=base_url, - timeout=timeout, - max_retries=max_retries, - ) + # Client params + self.client_params = client_params + self.base_url = base_url + self.timeout = timeout + self.max_retries = max_retries + + self.client = Anthropic(**self._get_client_params()) # Store completion parameters self.max_tokens = max_tokens @@ -79,6 +81,26 @@ class AnthropicCompletion(BaseLLM): self.is_claude_3 = "claude-3" in model.lower() self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use + def _get_client_params(self) -> dict[str, Any]: + """Get client parameters.""" + + if self.api_key is None: + self.api_key = os.getenv("ANTHROPIC_API_KEY") + if self.api_key is None: + raise ValueError("ANTHROPIC_API_KEY is required") + + client_params = { + "api_key": self.api_key, + "base_url": self.base_url, + "timeout": self.timeout, + "max_retries": self.max_retries, + } + + if self.client_params: + client_params.update(self.client_params) + + return client_params + def call( self, messages: str | list[dict[str, str]], @@ -102,6 +124,7 @@ class AnthropicCompletion(BaseLLM): Chat completion response or tool call result """ try: + print("we are calling", messages) # Emit call started event self._emit_call_started_event( messages=messages, @@ -121,6 +144,7 @@ class AnthropicCompletion(BaseLLM): completion_params = self._prepare_completion_params( formatted_messages, system_message, tools ) + print("completion_params", completion_params) # Handle streaming vs non-streaming if self.stream: @@ -183,12 +207,25 @@ class AnthropicCompletion(BaseLLM): def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]: """Convert CrewAI tool format to Anthropic tool use format.""" - from crewai.llms.providers.utils.common import safe_tool_conversion - anthropic_tools = [] for tool in tools: - name, description, parameters = safe_tool_conversion(tool, "Anthropic") + if "input_schema" in tool and "name" in tool and "description" in tool: + anthropic_tools.append(tool) + continue + + try: + from crewai.llms.providers.utils.common import safe_tool_conversion + + 
name, description, parameters = safe_tool_conversion(tool, "Anthropic") + except (ImportError, Exception): + name = tool.get("name", "unknown_tool") + description = tool.get("description", "A tool function") + parameters = ( + tool.get("input_schema") + or tool.get("parameters") + or tool.get("schema") + ) anthropic_tool = { "name": name, @@ -196,7 +233,13 @@ class AnthropicCompletion(BaseLLM): } if parameters and isinstance(parameters, dict): - anthropic_tool["input_schema"] = parameters # type: ignore + anthropic_tool["input_schema"] = parameters + else: + anthropic_tool["input_schema"] = { + "type": "object", + "properties": {}, + "required": [], + } anthropic_tools.append(anthropic_tool) @@ -229,13 +272,11 @@ class AnthropicCompletion(BaseLLM): content = message.get("content", "") if role == "system": - # Extract system message - Anthropic handles it separately if system_message: system_message += f"\n\n{content}" else: system_message = content else: - # Add user/assistant messages - ensure both role and content are str, not None role_str = role if role is not None else "user" content_str = content if content is not None else "" formatted_messages.append({"role": role_str, "content": content_str}) @@ -259,6 +300,7 @@ class AnthropicCompletion(BaseLLM): ) -> str | Any: """Handle non-streaming message completion.""" try: + print("params", params) response: Message = self.client.messages.create(**params) except Exception as e: @@ -270,22 +312,22 @@ class AnthropicCompletion(BaseLLM): usage = self._extract_anthropic_token_usage(response) self._track_token_usage_internal(usage) + # Check if Claude wants to use tools if response.content and available_functions: - for content_block in response.content: - if isinstance(content_block, ToolUseBlock): - function_name = content_block.name - function_args = content_block.input + tool_uses = [ + block for block in response.content if isinstance(block, ToolUseBlock) + ] - result = self._handle_tool_execution( - function_name=function_name, - function_args=function_args, # type: ignore - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, - ) - - if result is not None: - return result + if tool_uses: + # Handle tool use conversation flow + return self._handle_tool_use_conversation( + response, + tool_uses, + params, + available_functions, + from_task, + from_agent, + ) # Extract text content content = "" @@ -350,26 +392,54 @@ class AnthropicCompletion(BaseLLM): # Handle completed tool uses if tool_uses and available_functions: - for tool_data in tool_uses.values(): - function_name = tool_data["name"] - + # Convert streamed tool uses to ToolUseBlock-like objects for consistency + tool_use_blocks = [] + for tool_id, tool_data in tool_uses.items(): try: function_args = json.loads(tool_data["input"]) except json.JSONDecodeError as e: logging.error(f"Failed to parse streamed tool arguments: {e}") continue - # Execute tool - result = self._handle_tool_execution( - function_name=function_name, - function_args=function_args, - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, + # Create a mock ToolUseBlock-like object + class MockToolUse: + def __init__(self, tool_id: str, name: str, input_args: dict): + self.id = tool_id + self.name = name + self.input = input_args + + tool_use_blocks.append( + MockToolUse(tool_id, tool_data["name"], function_args) ) - if result is not None: - return result + if tool_use_blocks: + # Create a mock response object for the tool conversation flow + class 
MockResponse: + def __init__(self, content_blocks): + self.content = content_blocks + + # Combine text content and tool uses in the response + response_content = [] + if full_response.strip(): # Add text content if any + + class MockTextBlock: + def __init__(self, text: str): + self.text = text + + response_content.append(MockTextBlock(full_response)) + + response_content.extend(tool_use_blocks) + mock_response = MockResponse(response_content) + + # Handle tool use conversation flow + return self._handle_tool_use_conversation( + mock_response, + tool_use_blocks, + params, + available_functions, + from_task, + from_agent, + ) # Apply stop words to full response full_response = self._apply_stop_words(full_response) @@ -385,6 +455,115 @@ class AnthropicCompletion(BaseLLM): return full_response + def _handle_tool_use_conversation( + self, + initial_response: Message + | Any, # Can be Message or mock response from streaming + tool_uses: list[ToolUseBlock] + | list[Any], # Can be ToolUseBlock or mock objects + params: dict[str, Any], + available_functions: dict[str, Any], + from_task: Any | None = None, + from_agent: Any | None = None, + ) -> str: + """Handle the complete tool use conversation flow. + + This implements the proper Anthropic tool use pattern: + 1. Claude requests tool use + 2. We execute the tools + 3. We send tool results back to Claude + 4. Claude processes results and generates final response + """ + # Execute all requested tools and collect results + tool_results = [] + + for tool_use in tool_uses: + function_name = tool_use.name + function_args = tool_use.input + + # Execute the tool + result = self._handle_tool_execution( + function_name=function_name, + function_args=function_args, # type: ignore + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + # Create tool result in Anthropic format + tool_result = { + "type": "tool_result", + "tool_use_id": tool_use.id, + "content": str(result) + if result is not None + else "Tool execution completed", + } + tool_results.append(tool_result) + + # Prepare follow-up conversation with tool results + follow_up_params = params.copy() + + # Add Claude's tool use response to conversation + assistant_message = {"role": "assistant", "content": initial_response.content} + + # Add user message with tool results + user_message = {"role": "user", "content": tool_results} + + # Update messages for follow-up call + follow_up_params["messages"] = params["messages"] + [ + assistant_message, + user_message, + ] + + try: + # Send tool results back to Claude for final response + final_response: Message = self.client.messages.create(**follow_up_params) + + # Track token usage for follow-up call + follow_up_usage = self._extract_anthropic_token_usage(final_response) + self._track_token_usage_internal(follow_up_usage) + + # Extract final text content + final_content = "" + if final_response.content: + for content_block in final_response.content: + if hasattr(content_block, "text"): + final_content += content_block.text + + final_content = self._apply_stop_words(final_content) + + # Emit completion event for the final response + self._emit_call_completed_event( + response=final_content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=follow_up_params["messages"], + ) + + # Log combined token usage + total_usage = { + "input_tokens": follow_up_usage.get("input_tokens", 0), + "output_tokens": follow_up_usage.get("output_tokens", 0), + "total_tokens": 
follow_up_usage.get("total_tokens", 0), + } + + if total_usage.get("total_tokens", 0) > 0: + logging.info(f"Anthropic API tool conversation usage: {total_usage}") + + return final_content + + except Exception as e: + if is_context_length_exceeded(e): + logging.error(f"Context window exceeded in tool follow-up: {e}") + raise LLMContextLengthExceededError(str(e)) from e + + logging.error(f"Tool follow-up conversation failed: {e}") + # Fallback: return the first tool result if follow-up fails + if tool_results: + return tool_results[0]["content"] + raise e + def supports_function_calling(self) -> bool: """Check if the model supports function calling.""" return self.supports_tools diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py new file mode 100644 index 000000000..7d0780561 --- /dev/null +++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py @@ -0,0 +1,660 @@ +import os +import sys +import types +from unittest.mock import patch, MagicMock +import pytest + +from crewai.llm import LLM +from crewai.llms.providers.anthropic.completion import AnthropicCompletion +from crewai.crew import Crew +from crewai.agent import Agent +from crewai.task import Task +from crewai.cli.constants import DEFAULT_LLM_MODEL + + +def test_anthropic_completion_is_used_when_anthropic_provider(): + """ + Test that AnthropicCompletion from completion.py is used when LLM uses provider 'anthropic' + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + assert llm.__class__.__name__ == "AnthropicCompletion" + assert llm.provider == "anthropic" + assert llm.model == "claude-3-5-sonnet-20241022" + + +def test_anthropic_completion_is_used_when_claude_provider(): + """ + Test that AnthropicCompletion is used when provider is 'claude' + """ + llm = LLM(model="claude/claude-3-5-sonnet-20241022") + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + assert llm.provider == "claude" + assert llm.model == "claude-3-5-sonnet-20241022" + + + + +def test_anthropic_tool_use_conversation_flow(): + """ + Test that the Anthropic completion properly handles tool use conversation flow + """ + from unittest.mock import Mock, patch + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + from anthropic.types.tool_use_block import ToolUseBlock + + # Create AnthropicCompletion instance + completion = AnthropicCompletion(model="claude-3-5-sonnet-20241022") + + # Mock tool function + def mock_weather_tool(location: str) -> str: + return f"The weather in {location} is sunny and 75°F" + + available_functions = {"get_weather": mock_weather_tool} + + # Mock the Anthropic client responses + with patch.object(completion.client.messages, 'create') as mock_create: + # Mock initial response with tool use - need to properly mock ToolUseBlock + mock_tool_use = Mock(spec=ToolUseBlock) + mock_tool_use.id = "tool_123" + mock_tool_use.name = "get_weather" + mock_tool_use.input = {"location": "San Francisco"} + + mock_initial_response = Mock() + mock_initial_response.content = [mock_tool_use] + mock_initial_response.usage = Mock() + mock_initial_response.usage.input_tokens = 100 + mock_initial_response.usage.output_tokens = 50 + + # Mock final response after tool result - properly mock text content + mock_text_block = Mock() + # Set the text attribute as a string, not another Mock + mock_text_block.configure_mock(text="Based on the weather data, it's a beautiful day in San Francisco with sunny 
skies and 75°F temperature.") + + mock_final_response = Mock() + mock_final_response.content = [mock_text_block] + mock_final_response.usage = Mock() + mock_final_response.usage.input_tokens = 150 + mock_final_response.usage.output_tokens = 75 + + # Configure mock to return different responses on successive calls + mock_create.side_effect = [mock_initial_response, mock_final_response] + + # Test the call + messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}] + result = completion.call( + messages=messages, + available_functions=available_functions + ) + + # Verify the result contains the final response + assert "beautiful day in San Francisco" in result + assert "sunny skies" in result + assert "75°F" in result + + # Verify that two API calls were made (initial + follow-up) + assert mock_create.call_count == 2 + + # Verify the second call includes tool results + second_call_args = mock_create.call_args_list[1][1] # kwargs of second call + messages_in_second_call = second_call_args["messages"] + + # Should have original user message + assistant tool use + user tool result + assert len(messages_in_second_call) == 3 + assert messages_in_second_call[0]["role"] == "user" + assert messages_in_second_call[1]["role"] == "assistant" + assert messages_in_second_call[2]["role"] == "user" + + # Verify tool result format + tool_result = messages_in_second_call[2]["content"][0] + assert tool_result["type"] == "tool_result" + assert tool_result["tool_use_id"] == "tool_123" + assert "sunny and 75°F" in tool_result["content"] + + +def test_anthropic_completion_module_is_imported(): + """ + Test that the completion module is properly imported when using Anthropic provider + """ + module_name = "crewai.llms.providers.anthropic.completion" + + # Remove module from cache if it exists + if module_name in sys.modules: + del sys.modules[module_name] + + # Create LLM instance - this should trigger the import + LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Verify the module was imported + assert module_name in sys.modules + completion_mod = sys.modules[module_name] + assert isinstance(completion_mod, types.ModuleType) + + # Verify the class exists in the module + assert hasattr(completion_mod, 'AnthropicCompletion') + + +def test_fallback_to_litellm_when_native_anthropic_fails(): + """ + Test that LLM falls back to LiteLLM when native Anthropic completion fails + """ + # Mock the _get_native_provider to return a failing class + with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider: + + class FailingCompletion: + def __init__(self, *args, **kwargs): + raise Exception("Native Anthropic SDK failed") + + mock_get_provider.return_value = FailingCompletion + + # This should fall back to LiteLLM + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Check that it's using LiteLLM + assert hasattr(llm, 'is_litellm') + assert llm.is_litellm == True + + +def test_anthropic_completion_initialization_parameters(): + """ + Test that AnthropicCompletion is initialized with correct parameters + """ + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + temperature=0.7, + max_tokens=2000, + top_p=0.9, + api_key="test-key" + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + assert llm.model == "claude-3-5-sonnet-20241022" + assert llm.temperature == 0.7 + assert llm.max_tokens == 2000 + assert llm.top_p == 0.9 + + +def test_anthropic_specific_parameters(): + """ + Test 
Anthropic-specific parameters like stop_sequences and streaming + """ + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + stop_sequences=["Human:", "Assistant:"], + stream=True, + max_retries=5, + timeout=60 + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + assert llm.stop_sequences == ["Human:", "Assistant:"] + assert llm.stream == True + assert llm.client.max_retries == 5 + assert llm.client.timeout == 60 + + +def test_anthropic_completion_call(): + """ + Test that AnthropicCompletion call method works + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the call method on the instance + with patch.object(llm, 'call', return_value="Hello! I'm Claude, ready to help.") as mock_call: + result = llm.call("Hello, how are you?") + + assert result == "Hello! I'm Claude, ready to help." + mock_call.assert_called_once_with("Hello, how are you?") + + +def test_anthropic_completion_called_during_crew_execution(): + """ + Test that AnthropicCompletion.call is actually invoked when running a crew + """ + # Create the LLM instance first + anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the call method on the specific instance + with patch.object(anthropic_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call: + + # Create agent with explicit LLM configuration + agent = Agent( + role="Research Assistant", + goal="Find population info", + backstory="You research populations.", + llm=anthropic_llm, + ) + + task = Task( + description="Find Tokyo population", + expected_output="Population number", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + result = crew.kickoff() + + # Verify mock was called + assert mock_call.called + assert "14 million" in str(result) + + +def test_anthropic_completion_call_arguments(): + """ + Test that AnthropicCompletion.call is invoked with correct arguments + """ + # Create LLM instance first + anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the instance method + with patch.object(anthropic_llm, 'call') as mock_call: + mock_call.return_value = "Task completed successfully." + + agent = Agent( + role="Test Agent", + goal="Complete a simple task", + backstory="You are a test agent.", + llm=anthropic_llm # Use same instance + ) + + task = Task( + description="Say hello world", + expected_output="Hello world", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() + + # Verify call was made + assert mock_call.called + + # Check the arguments passed to the call method + call_args = mock_call.call_args + assert call_args is not None + + # The first argument should be the messages + messages = call_args[0][0] # First positional argument + assert isinstance(messages, (str, list)) + + # Verify that the task description appears in the messages + if isinstance(messages, str): + assert "hello world" in messages.lower() + elif isinstance(messages, list): + message_content = str(messages).lower() + assert "hello world" in message_content + + +def test_multiple_anthropic_calls_in_crew(): + """ + Test that AnthropicCompletion.call is invoked multiple times for multiple tasks + """ + # Create LLM instance first + anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the instance method + with patch.object(anthropic_llm, 'call') as mock_call: + mock_call.return_value = "Task completed." 
+ + agent = Agent( + role="Multi-task Agent", + goal="Complete multiple tasks", + backstory="You can handle multiple tasks.", + llm=anthropic_llm # Use same instance + ) + + task1 = Task( + description="First task", + expected_output="First result", + agent=agent, + ) + + task2 = Task( + description="Second task", + expected_output="Second result", + agent=agent, + ) + + crew = Crew( + agents=[agent], + tasks=[task1, task2] + ) + crew.kickoff() + + # Verify multiple calls were made + assert mock_call.call_count >= 2 # At least one call per task + + # Verify each call had proper arguments + for call in mock_call.call_args_list: + assert len(call[0]) > 0 # Has positional arguments + messages = call[0][0] + assert messages is not None + + +def test_anthropic_completion_with_tools(): + """ + Test that AnthropicCompletion.call is invoked with tools when agent has tools + """ + from crewai.tools import tool + + @tool + def sample_tool(query: str) -> str: + """A sample tool for testing""" + return f"Tool result for: {query}" + + # Create LLM instance first + anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the instance method + with patch.object(anthropic_llm, 'call') as mock_call: + mock_call.return_value = "Task completed with tools." + + agent = Agent( + role="Tool User", + goal="Use tools to complete tasks", + backstory="You can use tools.", + llm=anthropic_llm, # Use same instance + tools=[sample_tool] + ) + + task = Task( + description="Use the sample tool", + expected_output="Tool usage result", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() + + assert mock_call.called + + call_args = mock_call.call_args + call_kwargs = call_args[1] if len(call_args) > 1 else {} + + if 'tools' in call_kwargs: + assert call_kwargs['tools'] is not None + assert len(call_kwargs['tools']) > 0 + + +def test_anthropic_raises_error_when_model_not_supported(): + """Test that AnthropicCompletion raises ValueError when model not supported""" + + # Mock the Anthropic client to raise an error + with patch('crewai.llms.providers.anthropic.completion.Anthropic') as mock_anthropic_class: + mock_client = MagicMock() + mock_anthropic_class.return_value = mock_client + + # Mock the error that Anthropic would raise for unsupported models + from anthropic import NotFoundError + mock_client.messages.create.side_effect = NotFoundError( + message="The model `model-doesnt-exist` does not exist", + response=MagicMock(), + body={} + ) + + llm = LLM(model="anthropic/model-doesnt-exist") + + with pytest.raises(Exception): # Should raise some error for unsupported model + llm.call("Hello") + + +def test_anthropic_client_params_setup(): + """ + Test that client_params are properly merged with default client parameters + """ + # Use only valid Anthropic client parameters + custom_client_params = { + "default_headers": {"X-Custom-Header": "test-value"}, + } + + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + api_key="test-key", + base_url="https://custom-api.com", + timeout=45, + max_retries=5, + client_params=custom_client_params + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + + assert llm.client_params == custom_client_params + + merged_params = llm._get_client_params() + + assert merged_params["api_key"] == "test-key" + assert merged_params["base_url"] == "https://custom-api.com" + assert merged_params["timeout"] 
== 45 + assert merged_params["max_retries"] == 5 + + assert merged_params["default_headers"] == {"X-Custom-Header": "test-value"} + + +def test_anthropic_client_params_override_defaults(): + """ + Test that client_params can override default client parameters + """ + override_client_params = { + "timeout": 120, # Override the timeout parameter + "max_retries": 10, # Override the max_retries parameter + "default_headers": {"X-Override": "true"} # Valid custom parameter + } + + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + api_key="test-key", + timeout=30, + max_retries=3, + client_params=override_client_params + ) + + # Verify this is actually AnthropicCompletion, not LiteLLM fallback + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + + merged_params = llm._get_client_params() + + # client_params should override the individual parameters + assert merged_params["timeout"] == 120 + assert merged_params["max_retries"] == 10 + assert merged_params["default_headers"] == {"X-Override": "true"} + + +def test_anthropic_client_params_none(): + """ + Test that client_params=None works correctly (no additional parameters) + """ + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + api_key="test-key", + base_url="https://api.anthropic.com", + timeout=60, + max_retries=2, + client_params=None + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + + assert llm.client_params is None + + merged_params = llm._get_client_params() + + expected_keys = {"api_key", "base_url", "timeout", "max_retries"} + assert set(merged_params.keys()) == expected_keys + + # Fixed assertions - all should be inside the with block and use correct values + assert merged_params["api_key"] == "test-key" # Not "test-anthropic-key" + assert merged_params["base_url"] == "https://api.anthropic.com" + assert merged_params["timeout"] == 60 + assert merged_params["max_retries"] == 2 + + +def test_anthropic_client_params_empty_dict(): + """ + Test that client_params={} works correctly (empty additional parameters) + """ + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + llm = LLM( + model="anthropic/claude-3-5-sonnet-20241022", + api_key="test-key", + client_params={} + ) + + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion) + + assert llm.client_params == {} + + merged_params = llm._get_client_params() + + assert "api_key" in merged_params + assert merged_params["api_key"] == "test-key" + + +def test_anthropic_model_detection(): + """ + Test that various Anthropic model formats are properly detected + """ + # Test Anthropic model naming patterns that actually work with provider detection + anthropic_test_cases = [ + "anthropic/claude-3-5-sonnet-20241022", + "claude/claude-3-5-sonnet-20241022" + ] + + for model_name in anthropic_test_cases: + llm = LLM(model=model_name) + from crewai.llms.providers.anthropic.completion import AnthropicCompletion + assert isinstance(llm, AnthropicCompletion), f"Failed for model: {model_name}" + + +def test_anthropic_supports_stop_words(): + """ + Test that Anthropic models support stop sequences + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + assert llm.supports_stop_words() == True + + +def 
test_anthropic_context_window_size(): + """ + Test that Anthropic models return correct context window sizes + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + context_size = llm.get_context_window_size() + + # Should return a reasonable context window size (Claude 3.5 has 200k tokens) + assert context_size > 100000 # Should be substantial + assert context_size <= 200000 # But not exceed the actual limit + + +def test_anthropic_message_formatting(): + """ + Test that messages are properly formatted for Anthropic API + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Test message formatting + test_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + + formatted_messages, system_message = llm._format_messages_for_anthropic(test_messages) + + # System message should be extracted + assert system_message == "You are a helpful assistant." + + # Remaining messages should start with user + assert formatted_messages[0]["role"] == "user" + assert len(formatted_messages) >= 3 # Should have user, assistant, user messages + + +def test_anthropic_streaming_parameter(): + """ + Test that streaming parameter is properly handled + """ + # Test non-streaming + llm_no_stream = LLM(model="anthropic/claude-3-5-sonnet-20241022", stream=False) + assert llm_no_stream.stream == False + + # Test streaming + llm_stream = LLM(model="anthropic/claude-3-5-sonnet-20241022", stream=True) + assert llm_stream.stream == True + + +def test_anthropic_tool_conversion(): + """ + Test that tools are properly converted to Anthropic format + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock tool in CrewAI format + crewai_tools = [{ + "type": "function", + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"] + } + } + }] + + # Test tool conversion + anthropic_tools = llm._convert_tools_for_interference(crewai_tools) + + assert len(anthropic_tools) == 1 + assert anthropic_tools[0]["name"] == "test_tool" + assert anthropic_tools[0]["description"] == "A test tool" + assert "input_schema" in anthropic_tools[0] + + +def test_anthropic_environment_variable_api_key(): + """ + Test that Anthropic API key is properly loaded from environment + """ + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}): + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + assert llm.client is not None + assert hasattr(llm.client, 'messages') + + +def test_anthropic_token_usage_tracking(): + """ + Test that token usage is properly tracked for Anthropic responses + """ + llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") + + # Mock the Anthropic response with usage information + with patch.object(llm.client.messages, 'create') as mock_create: + mock_response = MagicMock() + mock_response.content = [MagicMock(text="test response")] + mock_response.usage = MagicMock(input_tokens=50, output_tokens=25) + mock_create.return_value = mock_response + + result = llm.call("Hello") + + # Verify the response + assert result == "test response" + + # Verify token usage was extracted + usage = llm._extract_anthropic_token_usage(mock_response) + assert usage["input_tokens"] == 50 + assert usage["output_tokens"] == 25 + assert usage["total_tokens"] == 75 From 
0073b4206f5ef145658e5681e3159fee9a0d7fa5 Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Tue, 14 Oct 2025 15:36:30 -0700 Subject: [PATCH 06/11] drop print statements --- lib/crewai/src/crewai/llms/providers/anthropic/completion.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index a90f06573..ffcaf3077 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -124,7 +124,6 @@ class AnthropicCompletion(BaseLLM): Chat completion response or tool call result """ try: - print("we are calling", messages) # Emit call started event self._emit_call_started_event( messages=messages, @@ -144,7 +143,6 @@ class AnthropicCompletion(BaseLLM): completion_params = self._prepare_completion_params( formatted_messages, system_message, tools ) - print("completion_params", completion_params) # Handle streaming vs non-streaming if self.stream: @@ -300,7 +298,6 @@ class AnthropicCompletion(BaseLLM): ) -> str | Any: """Handle non-streaming message completion.""" try: - print("params", params) response: Message = self.client.messages.create(**params) except Exception as e: From 44bbccdb751394aba7b940497b3efde3c35ddcae Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Wed, 15 Oct 2025 11:12:35 -0700 Subject: [PATCH 07/11] test: add fixture to mock ANTHROPIC_API_KEY for tests - Introduced a pytest fixture to automatically mock the ANTHROPIC_API_KEY environment variable for all tests in the test_anthropic.py module. - This change ensures that tests can run without requiring a real API key, improving test isolation and reliability. --- lib/crewai/tests/llms/anthropic/test_anthropic.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py index 7d0780561..90a0eb766 100644 --- a/lib/crewai/tests/llms/anthropic/test_anthropic.py +++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py @@ -5,11 +5,16 @@ from unittest.mock import patch, MagicMock import pytest from crewai.llm import LLM -from crewai.llms.providers.anthropic.completion import AnthropicCompletion from crewai.crew import Crew from crewai.agent import Agent from crewai.task import Task -from crewai.cli.constants import DEFAULT_LLM_MODEL + + +@pytest.fixture(autouse=True) +def mock_anthropic_api_key(): + """Automatically mock ANTHROPIC_API_KEY for all tests in this module.""" + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + yield def test_anthropic_completion_is_used_when_anthropic_provider(): From c5455142c3ccb0f837c7e4a97f8e996f6d159f1f Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Wed, 15 Oct 2025 18:57:27 -0700 Subject: [PATCH 08/11] feat: enhance GeminiCompletion class with additional client parameters and refactor client initialization - Added support for client_params in the GeminiCompletion class to allow for additional client configuration. - Refactored client initialization to use a dedicated method for retrieving client parameters, improving code organization and clarity. - Introduced comprehensive test cases to validate the functionality of the GeminiCompletion class, ensuring proper handling of tool use and parameter management. 
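A minimal usage sketch of the client_params plumbing this patch describes (not part of the diff below; it assumes the LLM entry point and gemini/ provider routing exercised in the tests, and that the http_options keys match the installed google-genai version):

    from crewai.llm import LLM

    # client_params is merged into the kwargs that GeminiCompletion forwards to genai.Client(**...)
    llm = LLM(
        model="gemini/gemini-2.0-flash-001",
        api_key="YOUR_GEMINI_API_KEY",  # placeholder; GOOGLE_API_KEY / GEMINI_API_KEY also work
        client_params={"http_options": {"timeout": 30_000}},  # exact keys/units depend on the SDK version
    )
    print(llm.call("Say hello in one sentence."))
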
--- .../llms/providers/gemini/completion.py | 107 ++- lib/crewai/tests/llms/google/test_google.py | 644 ++++++++++++++++++ 2 files changed, 730 insertions(+), 21 deletions(-) create mode 100644 lib/crewai/tests/llms/google/test_google.py diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index 7012e5ca0..987a55b49 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -11,9 +11,9 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( try: - from google import genai # type: ignore - from google.genai import types # type: ignore - from google.genai.errors import APIError # type: ignore + from google import genai + from google.genai import types + from google.genai.errors import APIError except ImportError: raise ImportError( "Google Gen AI native provider not available, to install: `uv add google-genai`" @@ -40,6 +40,7 @@ class GeminiCompletion(BaseLLM): stop_sequences: list[str] | None = None, stream: bool = False, safety_settings: dict[str, Any] | None = None, + client_params: dict[str, Any] | None = None, **kwargs, ): """Initialize Google Gemini chat completion client. @@ -56,35 +57,27 @@ class GeminiCompletion(BaseLLM): stop_sequences: Stop sequences stream: Enable streaming responses safety_settings: Safety filter settings + client_params: Additional parameters to pass to the Google Gen AI Client constructor. + Supports parameters like http_options, credentials, debug_config, etc. **kwargs: Additional parameters """ super().__init__( model=model, temperature=temperature, stop=stop_sequences or [], **kwargs ) - # Get API configuration + # Store client params for later use + self.client_params = client_params or {} + + # Get API configuration with environment variable fallbacks self.api_key = ( api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") ) self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT") self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1" - # Initialize client based on available configuration - if self.project: - # Use Vertex AI - self.client = genai.Client( - vertexai=True, - project=self.project, - location=self.location, - ) - elif self.api_key: - # Use Gemini Developer API - self.client = genai.Client(api_key=self.api_key) - else: - raise ValueError( - "Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or " - "GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set" - ) + use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" + + self.client = self._initialize_client(use_vertexai) # Store completion parameters self.top_p = top_p @@ -99,6 +92,78 @@ class GeminiCompletion(BaseLLM): self.is_gemini_1_5 = "gemini-1.5" in model.lower() self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2 + def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: + """Initialize the Google Gen AI client with proper parameter handling. 
+ + Args: + use_vertexai: Whether to use Vertex AI (from environment variable) + + Returns: + Initialized Google Gen AI Client + """ + client_params = {} + + if self.client_params: + client_params.update(self.client_params) + + if use_vertexai or self.project: + client_params.update( + { + "vertexai": True, + "project": self.project, + "location": self.location, + } + ) + + client_params.pop("api_key", None) + + elif self.api_key: + client_params["api_key"] = self.api_key + + client_params.pop("vertexai", None) + client_params.pop("project", None) + client_params.pop("location", None) + + else: + try: + return genai.Client(**client_params) + except Exception as e: + raise ValueError( + "Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or " + "GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set" + ) from e + + return genai.Client(**client_params) + + def _get_client_params(self) -> dict[str, Any]: + """Get client parameters for compatibility with base class. + + Note: This method is kept for compatibility but the Google Gen AI SDK + uses a different initialization pattern via the Client constructor. + """ + params = {} + + if ( + hasattr(self, "client") + and hasattr(self.client, "vertexai") + and self.client.vertexai + ): + # Vertex AI configuration + params.update( + { + "vertexai": True, + "project": self.project, + "location": self.location, + } + ) + elif self.api_key: + params["api_key"] = self.api_key + + if self.client_params: + params.update(self.client_params) + + return params + def call( self, messages: str | list[dict[str, str]], @@ -427,7 +492,7 @@ class GeminiCompletion(BaseLLM): def supports_stop_words(self) -> bool: """Check if the model supports stop words.""" - return self._supports_stop_words_implementation() + return True def get_context_window_size(self) -> int: """Get the context window size for the model.""" diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py new file mode 100644 index 000000000..ce6219a0c --- /dev/null +++ b/lib/crewai/tests/llms/google/test_google.py @@ -0,0 +1,644 @@ +import os +import sys +import types +from unittest.mock import patch, MagicMock +import pytest + +from crewai.llm import LLM +from crewai.crew import Crew +from crewai.agent import Agent +from crewai.task import Task + + +@pytest.fixture(autouse=True) +def mock_anthropic_api_key(): + """Automatically mock ANTHROPIC_API_KEY for all tests in this module.""" + with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}): + yield + + +def test_gemini_completion_is_used_when_google_provider(): + """ + Test that GeminiCompletion from completion.py is used when LLM uses provider 'google' + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + assert llm.__class__.__name__ == "GeminiCompletion" + assert llm.provider == "google" + assert llm.model == "gemini-2.0-flash-001" + + +def test_gemini_completion_is_used_when_gemini_provider(): + """ + Test that GeminiCompletion is used when provider is 'gemini' + """ + llm = LLM(model="gemini/gemini-2.0-flash-001") + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + assert llm.provider == "gemini" + assert llm.model == "gemini-2.0-flash-001" + + + + +def test_gemini_tool_use_conversation_flow(): + """ + Test that the Gemini completion properly handles tool use conversation flow + """ + from unittest.mock import Mock, patch + from crewai.llms.providers.gemini.completion import GeminiCompletion + + # Create GeminiCompletion 
instance + completion = GeminiCompletion(model="gemini-2.0-flash-001") + + # Mock tool function + def mock_weather_tool(location: str) -> str: + return f"The weather in {location} is sunny and 75°F" + + available_functions = {"get_weather": mock_weather_tool} + + # Mock the Google Gemini client responses + with patch.object(completion.client.models, 'generate_content') as mock_generate: + # Mock function call in response + mock_function_call = Mock() + mock_function_call.name = "get_weather" + mock_function_call.args = {"location": "San Francisco"} + + mock_part = Mock() + mock_part.function_call = mock_function_call + + mock_content = Mock() + mock_content.parts = [mock_part] + + mock_candidate = Mock() + mock_candidate.content = mock_content + + mock_response = Mock() + mock_response.candidates = [mock_candidate] + mock_response.text = "Based on the weather data, it's a beautiful day in San Francisco with sunny skies and 75°F temperature." + mock_response.usage_metadata = Mock() + mock_response.usage_metadata.prompt_token_count = 100 + mock_response.usage_metadata.candidates_token_count = 50 + mock_response.usage_metadata.total_token_count = 150 + + mock_generate.return_value = mock_response + + # Test the call + messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}] + result = completion.call( + messages=messages, + available_functions=available_functions + ) + + # Verify the tool was executed and returned the result + assert result == "The weather in San Francisco is sunny and 75°F" + + # Verify that the API was called + assert mock_generate.called + + +def test_gemini_completion_module_is_imported(): + """ + Test that the completion module is properly imported when using Google provider + """ + module_name = "crewai.llms.providers.gemini.completion" + + # Remove module from cache if it exists + if module_name in sys.modules: + del sys.modules[module_name] + + # Create LLM instance - this should trigger the import + LLM(model="google/gemini-2.0-flash-001") + + # Verify the module was imported + assert module_name in sys.modules + completion_mod = sys.modules[module_name] + assert isinstance(completion_mod, types.ModuleType) + + # Verify the class exists in the module + assert hasattr(completion_mod, 'GeminiCompletion') + + +def test_fallback_to_litellm_when_native_gemini_fails(): + """ + Test that LLM falls back to LiteLLM when native Gemini completion fails + """ + # Mock the _get_native_provider to return a failing class + with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider: + + class FailingCompletion: + def __init__(self, *args, **kwargs): + raise Exception("Native Google Gen AI SDK failed") + + mock_get_provider.return_value = FailingCompletion + + # This should fall back to LiteLLM + llm = LLM(model="google/gemini-2.0-flash-001") + + # Check that it's using LiteLLM + assert hasattr(llm, 'is_litellm') + assert llm.is_litellm == True + + +def test_gemini_completion_initialization_parameters(): + """ + Test that GeminiCompletion is initialized with correct parameters + """ + llm = LLM( + model="google/gemini-2.0-flash-001", + temperature=0.7, + max_output_tokens=2000, + top_p=0.9, + top_k=40, + api_key="test-key" + ) + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + assert llm.model == "gemini-2.0-flash-001" + assert llm.temperature == 0.7 + assert llm.max_output_tokens == 2000 + assert llm.top_p == 0.9 + assert llm.top_k == 40 + + +def 
test_gemini_specific_parameters(): + """ + Test Gemini-specific parameters like stop_sequences, streaming, and safety settings + """ + safety_settings = { + "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE", + "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE" + } + + llm = LLM( + model="google/gemini-2.0-flash-001", + stop_sequences=["Human:", "Assistant:"], + stream=True, + safety_settings=safety_settings, + project="test-project", + location="us-central1" + ) + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + assert llm.stop_sequences == ["Human:", "Assistant:"] + assert llm.stream == True + assert llm.safety_settings == safety_settings + assert llm.project == "test-project" + assert llm.location == "us-central1" + + +def test_gemini_completion_call(): + """ + Test that GeminiCompletion call method works + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the call method on the instance + with patch.object(llm, 'call', return_value="Hello! I'm Gemini, ready to help.") as mock_call: + result = llm.call("Hello, how are you?") + + assert result == "Hello! I'm Gemini, ready to help." + mock_call.assert_called_once_with("Hello, how are you?") + + +def test_gemini_completion_called_during_crew_execution(): + """ + Test that GeminiCompletion.call is actually invoked when running a crew + """ + # Create the LLM instance first + gemini_llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the call method on the specific instance + with patch.object(gemini_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call: + + # Create agent with explicit LLM configuration + agent = Agent( + role="Research Assistant", + goal="Find population info", + backstory="You research populations.", + llm=gemini_llm, + ) + + task = Task( + description="Find Tokyo population", + expected_output="Population number", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + result = crew.kickoff() + + # Verify mock was called + assert mock_call.called + assert "14 million" in str(result) + + +def test_gemini_completion_call_arguments(): + """ + Test that GeminiCompletion.call is invoked with correct arguments + """ + # Create LLM instance first + gemini_llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the instance method + with patch.object(gemini_llm, 'call') as mock_call: + mock_call.return_value = "Task completed successfully." 
+ + agent = Agent( + role="Test Agent", + goal="Complete a simple task", + backstory="You are a test agent.", + llm=gemini_llm # Use same instance + ) + + task = Task( + description="Say hello world", + expected_output="Hello world", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() + + # Verify call was made + assert mock_call.called + + # Check the arguments passed to the call method + call_args = mock_call.call_args + assert call_args is not None + + # The first argument should be the messages + messages = call_args[0][0] # First positional argument + assert isinstance(messages, (str, list)) + + # Verify that the task description appears in the messages + if isinstance(messages, str): + assert "hello world" in messages.lower() + elif isinstance(messages, list): + message_content = str(messages).lower() + assert "hello world" in message_content + + +def test_multiple_gemini_calls_in_crew(): + """ + Test that GeminiCompletion.call is invoked multiple times for multiple tasks + """ + # Create LLM instance first + gemini_llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the instance method + with patch.object(gemini_llm, 'call') as mock_call: + mock_call.return_value = "Task completed." + + agent = Agent( + role="Multi-task Agent", + goal="Complete multiple tasks", + backstory="You can handle multiple tasks.", + llm=gemini_llm # Use same instance + ) + + task1 = Task( + description="First task", + expected_output="First result", + agent=agent, + ) + + task2 = Task( + description="Second task", + expected_output="Second result", + agent=agent, + ) + + crew = Crew( + agents=[agent], + tasks=[task1, task2] + ) + crew.kickoff() + + # Verify multiple calls were made + assert mock_call.call_count >= 2 # At least one call per task + + # Verify each call had proper arguments + for call in mock_call.call_args_list: + assert len(call[0]) > 0 # Has positional arguments + messages = call[0][0] + assert messages is not None + + +def test_gemini_completion_with_tools(): + """ + Test that GeminiCompletion.call is invoked with tools when agent has tools + """ + from crewai.tools import tool + + @tool + def sample_tool(query: str) -> str: + """A sample tool for testing""" + return f"Tool result for: {query}" + + # Create LLM instance first + gemini_llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the instance method + with patch.object(gemini_llm, 'call') as mock_call: + mock_call.return_value = "Task completed with tools." 
+ + agent = Agent( + role="Tool User", + goal="Use tools to complete tasks", + backstory="You can use tools.", + llm=gemini_llm, # Use same instance + tools=[sample_tool] + ) + + task = Task( + description="Use the sample tool", + expected_output="Tool usage result", + agent=agent, + ) + + crew = Crew(agents=[agent], tasks=[task]) + crew.kickoff() + + assert mock_call.called + + call_args = mock_call.call_args + call_kwargs = call_args[1] if len(call_args) > 1 else {} + + if 'tools' in call_kwargs: + assert call_kwargs['tools'] is not None + assert len(call_kwargs['tools']) > 0 + + +def test_gemini_raises_error_when_model_not_supported(): + """Test that GeminiCompletion raises ValueError when model not supported""" + + # Mock the Google client to raise an error + with patch('crewai.llms.providers.gemini.completion.genai') as mock_genai: + mock_client = MagicMock() + mock_genai.Client.return_value = mock_client + + # Mock the error that Google would raise for unsupported models + from google.genai.errors import ClientError # type: ignore + mock_client.models.generate_content.side_effect = ClientError( + code=404, + response_json={ + 'error': { + 'code': 404, + 'message': 'models/model-doesnt-exist is not found for API version v1beta, or is not supported for generateContent.', + 'status': 'NOT_FOUND' + } + } + ) + + llm = LLM(model="google/model-doesnt-exist") + + with pytest.raises(Exception): # Should raise some error for unsupported model + llm.call("Hello") + + +def test_gemini_vertex_ai_setup(): + """ + Test that Vertex AI configuration is properly handled + """ + with patch.dict(os.environ, { + "GOOGLE_CLOUD_PROJECT": "test-project", + "GOOGLE_CLOUD_LOCATION": "us-west1" + }): + llm = LLM( + model="google/gemini-2.0-flash-001", + project="test-project", + location="us-west1" + ) + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + + assert llm.project == "test-project" + assert llm.location == "us-west1" + + +def test_gemini_api_key_configuration(): + """ + Test that API key configuration works for both GOOGLE_API_KEY and GEMINI_API_KEY + """ + # Test with GOOGLE_API_KEY + with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}): + llm = LLM(model="google/gemini-2.0-flash-001") + + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + assert llm.api_key == "test-google-key" + + # Test with GEMINI_API_KEY + with patch.dict(os.environ, {"GEMINI_API_KEY": "test-gemini-key"}, clear=True): + llm = LLM(model="google/gemini-2.0-flash-001") + + assert isinstance(llm, GeminiCompletion) + assert llm.api_key == "test-gemini-key" + + +def test_gemini_model_capabilities(): + """ + Test that model capabilities are correctly identified + """ + # Test Gemini 2.0 model + llm_2_0 = LLM(model="google/gemini-2.0-flash-001") + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm_2_0, GeminiCompletion) + assert llm_2_0.is_gemini_2 == True + assert llm_2_0.supports_tools == True + + # Test Gemini 1.5 model + llm_1_5 = LLM(model="google/gemini-1.5-pro") + assert isinstance(llm_1_5, GeminiCompletion) + assert llm_1_5.is_gemini_1_5 == True + assert llm_1_5.supports_tools == True + + +def test_gemini_generation_config(): + """ + Test that generation config is properly prepared + """ + llm = LLM( + model="google/gemini-2.0-flash-001", + temperature=0.7, + top_p=0.9, + top_k=40, + max_output_tokens=1000 + ) + + from 
crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion) + + # Test config preparation + config = llm._prepare_generation_config() + + # Verify config has the expected parameters + assert hasattr(config, 'temperature') or 'temperature' in str(config) + assert hasattr(config, 'top_p') or 'top_p' in str(config) + assert hasattr(config, 'top_k') or 'top_k' in str(config) + assert hasattr(config, 'max_output_tokens') or 'max_output_tokens' in str(config) + + +def test_gemini_model_detection(): + """ + Test that various Gemini model formats are properly detected + """ + # Test Gemini model naming patterns that actually work with provider detection + gemini_test_cases = [ + "google/gemini-2.0-flash-001", + "gemini/gemini-2.0-flash-001", + "google/gemini-1.5-pro", + "gemini/gemini-1.5-flash" + ] + + for model_name in gemini_test_cases: + llm = LLM(model=model_name) + from crewai.llms.providers.gemini.completion import GeminiCompletion + assert isinstance(llm, GeminiCompletion), f"Failed for model: {model_name}" + + +def test_gemini_supports_stop_words(): + """ + Test that Gemini models support stop sequences + """ + llm = LLM(model="google/gemini-2.0-flash-001") + assert llm.supports_stop_words() == True + + +def test_gemini_context_window_size(): + """ + Test that Gemini models return correct context window sizes + """ + # Test Gemini 2.0 Flash + llm_2_0 = LLM(model="google/gemini-2.0-flash-001") + context_size_2_0 = llm_2_0.get_context_window_size() + assert context_size_2_0 > 500000 # Should be substantial (1M tokens) + + # Test Gemini 1.5 Pro + llm_1_5 = LLM(model="google/gemini-1.5-pro") + context_size_1_5 = llm_1_5.get_context_window_size() + assert context_size_1_5 > 1000000 # Should be very large (2M tokens) + + +def test_gemini_message_formatting(): + """ + Test that messages are properly formatted for Gemini API + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + # Test message formatting + test_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + + formatted_contents, system_instruction = llm._format_messages_for_gemini(test_messages) + + # System message should be extracted + assert system_instruction == "You are a helpful assistant." 
+ + # Remaining messages should be Content objects + assert len(formatted_contents) >= 3 # Should have user, model, user messages + + # First content should be user role + assert formatted_contents[0].role == "user" + # Second should be model (converted from assistant) + assert formatted_contents[1].role == "model" + + +def test_gemini_streaming_parameter(): + """ + Test that streaming parameter is properly handled + """ + # Test non-streaming + llm_no_stream = LLM(model="google/gemini-2.0-flash-001", stream=False) + assert llm_no_stream.stream == False + + # Test streaming + llm_stream = LLM(model="google/gemini-2.0-flash-001", stream=True) + assert llm_stream.stream == True + + +def test_gemini_tool_conversion(): + """ + Test that tools are properly converted to Gemini format + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock tool in CrewAI format + crewai_tools = [{ + "type": "function", + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"] + } + } + }] + + # Test tool conversion + gemini_tools = llm._convert_tools_for_interference(crewai_tools) + + assert len(gemini_tools) == 1 + # Gemini tools are Tool objects with function_declarations + assert hasattr(gemini_tools[0], 'function_declarations') + assert len(gemini_tools[0].function_declarations) == 1 + + func_decl = gemini_tools[0].function_declarations[0] + assert func_decl.name == "test_tool" + assert func_decl.description == "A test tool" + + +def test_gemini_environment_variable_api_key(): + """ + Test that Google API key is properly loaded from environment + """ + with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}): + llm = LLM(model="google/gemini-2.0-flash-001") + + assert llm.client is not None + assert hasattr(llm.client, 'models') + assert llm.api_key == "test-google-key" + + +def test_gemini_token_usage_tracking(): + """ + Test that token usage is properly tracked for Gemini responses + """ + llm = LLM(model="google/gemini-2.0-flash-001") + + # Mock the Gemini response with usage information + with patch.object(llm.client.models, 'generate_content') as mock_generate: + mock_response = MagicMock() + mock_response.text = "test response" + mock_response.candidates = [] + mock_response.usage_metadata = MagicMock( + prompt_token_count=50, + candidates_token_count=25, + total_token_count=75 + ) + mock_generate.return_value = mock_response + + result = llm.call("Hello") + + # Verify the response + assert result == "test response" + + # Verify token usage was extracted + usage = llm._extract_token_usage(mock_response) + assert usage["prompt_token_count"] == 50 + assert usage["candidates_token_count"] == 25 + assert usage["total_token_count"] == 75 + assert usage["total_tokens"] == 75 From 6150a358a312fe2ed134bbb07a60bda018fddd92 Mon Sep 17 00:00:00 2001 From: lorenzejay Date: Tue, 14 Oct 2025 15:34:28 -0700 Subject: [PATCH 09/11] feat: enhance AnthropicCompletion class with additional client parameters and tool handling - Added support for client_params in the AnthropicCompletion class to allow for additional client configuration. - Refactored client initialization to use a dedicated method for retrieving client parameters. - Implemented a new method to handle tool use conversation flow, ensuring proper execution and response handling. 
- Introduced comprehensive test cases to validate the functionality of the AnthropicCompletion class, including tool use scenarios and parameter handling.
---
 lib/crewai/src/crewai/llms/providers/anthropic/completion.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
index ffcaf3077..a90f06573 100644
--- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
@@ -124,6 +124,7 @@ class AnthropicCompletion(BaseLLM):
             Chat completion response or tool call result
         """
         try:
+            print("we are calling", messages)
             # Emit call started event
             self._emit_call_started_event(
                 messages=messages,
@@ -143,6 +144,7 @@ class AnthropicCompletion(BaseLLM):
             completion_params = self._prepare_completion_params(
                 formatted_messages, system_message, tools
             )
+            print("completion_params", completion_params)

             # Handle streaming vs non-streaming
             if self.stream:
@@ -298,6 +300,7 @@ class AnthropicCompletion(BaseLLM):
     ) -> str | Any:
         """Handle non-streaming message completion."""
         try:
+            print("params", params)
             response: Message = self.client.messages.create(**params)

         except Exception as e:

From ba30374ac4bf4199e8823a712d41bba5ee094921 Mon Sep 17 00:00:00 2001
From: lorenzejay
Date: Tue, 14 Oct 2025 15:36:30 -0700
Subject: [PATCH 10/11] drop print statements

---
 lib/crewai/src/crewai/llms/providers/anthropic/completion.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
index a90f06573..ffcaf3077 100644
--- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
@@ -124,7 +124,6 @@ class AnthropicCompletion(BaseLLM):
             Chat completion response or tool call result
         """
         try:
-            print("we are calling", messages)
             # Emit call started event
             self._emit_call_started_event(
                 messages=messages,
@@ -144,7 +143,6 @@ class AnthropicCompletion(BaseLLM):
             completion_params = self._prepare_completion_params(
                 formatted_messages, system_message, tools
             )
-            print("completion_params", completion_params)

             # Handle streaming vs non-streaming
             if self.stream:
@@ -300,7 +298,6 @@ class AnthropicCompletion(BaseLLM):
     ) -> str | Any:
         """Handle non-streaming message completion."""
         try:
-            print("params", params)
             response: Message = self.client.messages.create(**params)

         except Exception as e:

From fbd72ded440c5226bd45fdf24d3830ccd5a08fb2 Mon Sep 17 00:00:00 2001
From: lorenzejay
Date: Thu, 16 Oct 2025 10:49:58 -0700
Subject: [PATCH 11/11] no runners

---
 lib/crewai/runner.py | 72 --------------------------------------------
 1 file changed, 72 deletions(-)
 delete mode 100644 lib/crewai/runner.py

diff --git a/lib/crewai/runner.py b/lib/crewai/runner.py
deleted file mode 100644
index eab7e0f89..000000000
--- a/lib/crewai/runner.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from crewai_tools import EXASearchTool
-
-from crewai import LLM, Agent, Crew, Task
-import os
-
-
-llm = LLM(
-    model="anthropic/claude-3-5-sonnet-20241022",
-    api_key=os.getenv("ANTHROPIC_API_KEY"),
-)
-agent = Agent(
-    role="researcher",
-    backstory="A researcher who can research the web",
-    goal="Research the web",
-    tools=[EXASearchTool()],
-    llm=llm,
-)
-
-task = Task(
-    description="Research the web based on the query: {query}",
-    expected_output="A list of 10 bullet points of the most relevant information about the web",
-    agent=agent,
-)
-
-crew = Crew(
-    agents=[agent],
-    tasks=[task],
-    verbose=True,
-    tracing=True,
-)
-
-# result = crew.kickoff(inputs={"query": "What are ai agents?"})
-# print("result", result)
-# print("usage_metrics", result.token_usage)
-
-
-def anthropic_tool_use_runner():
-    def get_weather(location: str) -> str:
-        return f"The weather in {location} is sunny"
-
-    llm = LLM(
-        model="anthropic/claude-3-5-sonnet-20241022",
-        api_key=os.getenv("ANTHROPIC_API_KEY"),
-    )
-    result = llm.call(
-        messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
-        available_functions={"get_weather": get_weather},
-        tools=[
-            {
-                "type": "function",
-                "function": {
-                    "name": "get_weather",
-                    "description": "Get the weather in a location",
-                    "parameters": {
-                        "type": "object",
-                        "properties": {
-                            "location": {
-                                "type": "string",
-                                "description": "The location to get the weather for",
-                            }
-                        },
-                        "required": ["location"],
-                    },
-                },
-            }
-        ],
-    )
-    print("anthropic tool use result", result)
-
-
-if __name__ == "__main__":
-    anthropic_tool_use_runner()
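
Reviewer note: with runner.py removed in PATCH 11, the snippet below is a minimal sketch of how the client_params pass-through described in the PATCH 09 message could be exercised by hand. It is not part of the patch series; it assumes ANTHROPIC_API_KEY is exported, that AnthropicCompletion can be constructed directly from crewai.llms.providers.anthropic.completion, and that extra keys in client_params are forwarded verbatim to the Anthropic client constructor. The default_headers value is a placeholder for illustration only.

import os

from crewai.llms.providers.anthropic.completion import AnthropicCompletion

# Hypothetical smoke check for client_params forwarding (not part of the patch).
# Any keyword the Anthropic client constructor accepts could be supplied this way;
# the header below is purely illustrative.
llm = AnthropicCompletion(
    model="claude-3-5-sonnet-20241022",
    api_key=os.getenv("ANTHROPIC_API_KEY"),
    client_params={"default_headers": {"X-Request-Source": "crewai-client-params-check"}},
)

# A trivial call confirms the client was built with the merged parameters.
print(llm.call([{"role": "user", "content": "Reply with the single word: ok"}]))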