diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
index fad6f1904..2ef55bd17 100644
--- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py
@@ -36,7 +36,7 @@ class AnthropicCompletion(BaseLLM):
         timeout: float | None = None,
         max_retries: int = 2,
         temperature: float | None = None,
-        max_tokens: int = 4096,  # Required for Anthropic
+        max_tokens: int | None = None,  # Optional, computed dynamically if not set
         top_p: float | None = None,
         stop_sequences: list[str] | None = None,
         stream: bool = False,
@@ -52,7 +52,8 @@ class AnthropicCompletion(BaseLLM):
             timeout: Request timeout in seconds
             max_retries: Maximum number of retries
             temperature: Sampling temperature (0-1)
-            max_tokens: Maximum tokens in response (required for Anthropic)
+            max_tokens: Maximum tokens in response. If not set, it will be computed
+                dynamically based on context window size (recommended for most use cases)
             top_p: Nucleus sampling parameter
             stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
             stream: Enable streaming responses
@@ -72,7 +73,7 @@ class AnthropicCompletion(BaseLLM):
         self.client = Anthropic(**self._get_client_params())
 
         # Store completion parameters
-        self.max_tokens = max_tokens
+        self.max_tokens = max_tokens  # Can be None, will be computed dynamically
         self.top_p = top_p
         self.stream = stream
         self.stop_sequences = stop_sequences or []
@@ -178,10 +179,19 @@ class AnthropicCompletion(BaseLLM):
         Returns:
             Parameters dictionary for Anthropic API
         """
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            # Use a conservative default while still allowing enough tokens for most responses
+            max_tokens = 1024
+
+            context_window = self.get_context_window_size()
+            if context_window > 100000:  # For Claude models with 200k+ context
+                max_tokens = 2048
+
         params = {
             "model": self.model,
             "messages": messages,
-            "max_tokens": self.max_tokens,
+            "max_tokens": max_tokens,
             "stream": self.stream,
         }
diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py
index 37ba366b9..c03f0d3a1 100644
--- a/lib/crewai/tests/llms/anthropic/test_anthropic.py
+++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py
@@ -664,3 +664,105 @@ def test_anthropic_token_usage_tracking():
     assert usage["input_tokens"] == 50
     assert usage["output_tokens"] == 25
     assert usage["total_tokens"] == 75
+
+
+def test_anthropic_max_tokens_explicit():
+    """
+    Test that explicit max_tokens is passed through to the API
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=4096)
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        assert call_kwargs["max_tokens"] == 4096
+
+
+def test_anthropic_max_tokens_default_computed():
+    """
+    Test that max_tokens is computed dynamically when not explicitly set
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        assert "max_tokens" in call_kwargs
+        assert call_kwargs["max_tokens"] is not None
+        assert call_kwargs["max_tokens"] < 4096
+
+
+def test_anthropic_max_tokens_none_uses_dynamic_default():
+    """
+    Test that max_tokens=None results in dynamic computation
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=None)
+
+    assert llm.max_tokens is None
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        assert "max_tokens" in call_kwargs
+        assert call_kwargs["max_tokens"] is not None
+        assert call_kwargs["max_tokens"] < 4096
+
+
+def test_anthropic_max_tokens_dynamic_for_large_context():
+    """
+    Test that dynamic max_tokens is larger for models with large context windows
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        computed_max_tokens = call_kwargs["max_tokens"]
+
+        assert computed_max_tokens >= 1024
+        assert computed_max_tokens <= 2048
+
+
+def test_anthropic_max_tokens_respects_user_value():
+    """
+    Test that user-provided max_tokens is always respected
+    """
+    test_values = [512, 1024, 2048, 4096, 8192]
+
+    for max_tokens_value in test_values:
+        llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=max_tokens_value)
+
+        with patch.object(llm.client.messages, 'create') as mock_create:
+            mock_response = MagicMock()
+            mock_response.content = [MagicMock(text="test response")]
+            mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+            mock_create.return_value = mock_response
+
+            llm.call("Hello")
+
+            call_kwargs = mock_create.call_args[1]
+            assert call_kwargs["max_tokens"] == max_tokens_value
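
For reference, a minimal sketch of the max_tokens resolution behavior introduced above. The helper name _resolve_max_tokens is hypothetical (the diff inlines this logic in the params-building method); it assumes get_context_window_size() returns the model's context window in tokens.

    def _resolve_max_tokens(explicit: int | None, context_window: int) -> int:
        # A user-supplied max_tokens is always passed through unchanged.
        if explicit is not None:
            return explicit
        # Otherwise fall back to a conservative default, bumped for
        # large-context Claude models (e.g. 200k-token windows).
        return 2048 if context_window > 100000 else 1024

    _resolve_max_tokens(4096, 200000)  # -> 4096 (explicit value wins)
    _resolve_max_tokens(None, 200000)  # -> 2048 (large context window)
    _resolve_max_tokens(None, 8192)    # -> 1024 (conservative default)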