Fix Anthropic max_tokens issue causing slow execution (issue #3807)

Make max_tokens optional and compute it dynamically when the user does not set it.

Previously, max_tokens defaulted to 4096 and was always passed to the
Anthropic API, causing the model to generate up to 4096 tokens even for
simple queries that should only need a few tokens. This resulted in
extremely slow execution times.

Changes:
- Changed max_tokens parameter from int (default 4096) to int | None (default None)
- Added dynamic computation in _prepare_completion_params():
  - Default: 1024 tokens (much more reasonable for most queries)
  - Large context models (200k+): 2048 tokens
  - User-specified values are always respected
- Updated docstring to reflect that max_tokens is now optional
- Added comprehensive tests covering:
  - Explicit max_tokens values are passed through unchanged
  - Default behavior computes reasonable max_tokens dynamically
  - max_tokens=None uses dynamic computation
  - Dynamic values are appropriate for model context window size
  - User-provided values are always respected

This fix aligns with the v0.203.1 behavior where max_tokens was optional
and only passed when explicitly set, while maintaining compatibility with
the Anthropic SDK requirement that max_tokens must be provided.
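
A minimal usage sketch of the new behavior (the "from crewai import LLM" import
path and the attribute check are assumptions based on the tests below; the
resolved value depends on the model's reported context window):

    from crewai import LLM  # assumed import path for the wrapper used in the tests

    # Explicit value: passed through to the Anthropic API unchanged.
    llm_explicit = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=4096)

    # No value: max_tokens stays None on the instance and is resolved in
    # _prepare_completion_params() (1024, or 2048 for 200k+ context models).
    llm_dynamic = LLM(model="anthropic/claude-3-5-sonnet-20241022")
    assert llm_dynamic.max_tokens is None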

Fixes #3807

Co-Authored-By: João <joao@crewai.com>
Author: Devin AI
Date: 2025-10-29 00:26:09 +00:00
Parent: 70b083945f
Commit: 02f9a36acb
2 changed files with 116 additions and 4 deletions

Changed file 1 of 2: AnthropicCompletion implementation

@@ -36,7 +36,7 @@ class AnthropicCompletion(BaseLLM):
         timeout: float | None = None,
         max_retries: int = 2,
         temperature: float | None = None,
-        max_tokens: int = 4096,  # Required for Anthropic
+        max_tokens: int | None = None,  # Optional, computed dynamically if not set
         top_p: float | None = None,
         stop_sequences: list[str] | None = None,
         stream: bool = False,
@@ -52,7 +52,8 @@ class AnthropicCompletion(BaseLLM):
             timeout: Request timeout in seconds
             max_retries: Maximum number of retries
             temperature: Sampling temperature (0-1)
-            max_tokens: Maximum tokens in response (required for Anthropic)
+            max_tokens: Maximum tokens in response. If not set, will be computed
+                dynamically based on context window size (recommended for most use cases)
             top_p: Nucleus sampling parameter
             stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
             stream: Enable streaming responses
@@ -72,7 +73,7 @@ class AnthropicCompletion(BaseLLM):
         self.client = Anthropic(**self._get_client_params())

         # Store completion parameters
-        self.max_tokens = max_tokens
+        self.max_tokens = max_tokens  # Can be None, will be computed dynamically
         self.top_p = top_p
         self.stream = stream
         self.stop_sequences = stop_sequences or []
@@ -178,10 +179,19 @@ class AnthropicCompletion(BaseLLM):
         Returns:
             Parameters dictionary for Anthropic API
         """
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            # Use a conservative default to avoid unnecessarily slow generation
+            # while still allowing enough tokens for most responses
+            max_tokens = 1024
+            context_window = self.get_context_window_size()
+            if context_window > 100000:  # For Claude models with 200k+ context
+                max_tokens = 2048
+
         params = {
             "model": self.model,
             "messages": messages,
-            "max_tokens": self.max_tokens,
+            "max_tokens": max_tokens,
             "stream": self.stream,
         }

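For readability, the fallback above can be restated as a small pure function
(a sketch only, not additional code in this commit):

    def resolve_max_tokens(user_max_tokens: int | None, context_window: int) -> int:
        """Return the max_tokens value that will be sent to the Anthropic API."""
        if user_max_tokens is not None:
            return user_max_tokens      # user-provided values are always respected
        if context_window > 100000:     # Claude models with 200k+ context windows
            return 2048
        return 1024

    # e.g. a 200k-window model resolves to 2048; a smaller window resolves to 1024.
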
Changed file 2 of 2: Anthropic completion tests

@@ -664,3 +664,105 @@ def test_anthropic_token_usage_tracking():
     assert usage["input_tokens"] == 50
     assert usage["output_tokens"] == 25
     assert usage["total_tokens"] == 75
+
+
+def test_anthropic_max_tokens_explicit():
+    """
+    Test that explicit max_tokens is passed through to the API
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=4096)
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        assert call_kwargs["max_tokens"] == 4096
+
+
+def test_anthropic_max_tokens_default_computed():
+    """
+    Test that max_tokens is computed dynamically when not explicitly set
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        assert "max_tokens" in call_kwargs
+        assert call_kwargs["max_tokens"] is not None
+        assert call_kwargs["max_tokens"] < 4096
+
+
+def test_anthropic_max_tokens_none_uses_dynamic_default():
+    """
+    Test that max_tokens=None results in dynamic computation
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=None)
+    assert llm.max_tokens is None
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        assert "max_tokens" in call_kwargs
+        assert call_kwargs["max_tokens"] is not None
+        assert call_kwargs["max_tokens"] < 4096
+
+
+def test_anthropic_max_tokens_dynamic_for_large_context():
+    """
+    Test that dynamic max_tokens is larger for models with large context windows
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        computed_max_tokens = call_kwargs["max_tokens"]
+        assert computed_max_tokens >= 1024
+        assert computed_max_tokens <= 2048
+
+
+def test_anthropic_max_tokens_respects_user_value():
+    """
+    Test that user-provided max_tokens is always respected
+    """
+    test_values = [512, 1024, 2048, 4096, 8192]
+
+    for max_tokens_value in test_values:
+        llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=max_tokens_value)
+
+        with patch.object(llm.client.messages, 'create') as mock_create:
+            mock_response = MagicMock()
+            mock_response.content = [MagicMock(text="test response")]
+            mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+            mock_create.return_value = mock_response
+
+            llm.call("Hello")
+
+            call_kwargs = mock_create.call_args[1]
+            assert call_kwargs["max_tokens"] == max_tokens_value
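
To run only the new tests (assuming the suite is executed with pytest, as the
mock-based tests above suggest):

    pytest -k "anthropic_max_tokens" -q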