mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Fix Anthropic max_tokens issue causing slow execution (issue #3807)
Make max_tokens optional and compute dynamically when not set by user. Previously, max_tokens defaulted to 4096 and was always passed to the Anthropic API, causing the model to generate up to 4096 tokens even for simple queries that should only need a few tokens. This resulted in extremely slow execution times. Changes: - Changed max_tokens parameter from int (default 4096) to int | None (default None) - Added dynamic computation in _prepare_completion_params(): - Default: 1024 tokens (much more reasonable for most queries) - Large context models (200k+): 2048 tokens - User-specified values are always respected - Updated docstring to reflect that max_tokens is now optional - Added comprehensive tests covering: - Explicit max_tokens values are passed through unchanged - Default behavior computes reasonable max_tokens dynamically - max_tokens=None uses dynamic computation - Dynamic values are appropriate for model context window size - User-provided values are always respected This fix aligns with the v0.203.1 behavior where max_tokens was optional and only passed when explicitly set, while maintaining compatibility with the Anthropic SDK requirement that max_tokens must be provided. Fixes #3807 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -36,7 +36,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int = 4096, # Required for Anthropic
|
||||
max_tokens: int | None = None, # Optional, computed dynamically if not set
|
||||
top_p: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
@@ -52,7 +52,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries
|
||||
temperature: Sampling temperature (0-1)
|
||||
max_tokens: Maximum tokens in response (required for Anthropic)
|
||||
max_tokens: Maximum tokens in response. If not set, will be computed
|
||||
dynamically based on context window size (recommended for most use cases)
|
||||
top_p: Nucleus sampling parameter
|
||||
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
|
||||
stream: Enable streaming responses
|
||||
@@ -72,7 +73,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.client = Anthropic(**self._get_client_params())
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.max_tokens = max_tokens # Can be None, will be computed dynamically
|
||||
self.top_p = top_p
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences or []
|
||||
@@ -178,10 +179,19 @@ class AnthropicCompletion(BaseLLM):
|
||||
Returns:
|
||||
Parameters dictionary for Anthropic API
|
||||
"""
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
# while still allowing enough tokens for most responses
|
||||
max_tokens = 1024
|
||||
|
||||
context_window = self.get_context_window_size()
|
||||
if context_window > 100000: # For Claude models with 200k+ context
|
||||
max_tokens = 2048
|
||||
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
|
||||
@@ -664,3 +664,105 @@ def test_anthropic_token_usage_tracking():
|
||||
assert usage["input_tokens"] == 50
|
||||
assert usage["output_tokens"] == 25
|
||||
assert usage["total_tokens"] == 75
|
||||
|
||||
|
||||
def test_anthropic_max_tokens_explicit():
|
||||
"""
|
||||
Test that explicit max_tokens is passed through to the API
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=4096)
|
||||
|
||||
with patch.object(llm.client.messages, 'create') as mock_create:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
llm.call("Hello")
|
||||
|
||||
call_kwargs = mock_create.call_args[1]
|
||||
assert call_kwargs["max_tokens"] == 4096
|
||||
|
||||
|
||||
def test_anthropic_max_tokens_default_computed():
|
||||
"""
|
||||
Test that max_tokens is computed dynamically when not explicitly set
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
with patch.object(llm.client.messages, 'create') as mock_create:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
llm.call("Hello")
|
||||
|
||||
call_kwargs = mock_create.call_args[1]
|
||||
assert "max_tokens" in call_kwargs
|
||||
assert call_kwargs["max_tokens"] is not None
|
||||
assert call_kwargs["max_tokens"] < 4096
|
||||
|
||||
|
||||
def test_anthropic_max_tokens_none_uses_dynamic_default():
|
||||
"""
|
||||
Test that max_tokens=None results in dynamic computation
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=None)
|
||||
|
||||
assert llm.max_tokens is None
|
||||
|
||||
with patch.object(llm.client.messages, 'create') as mock_create:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
llm.call("Hello")
|
||||
|
||||
call_kwargs = mock_create.call_args[1]
|
||||
assert "max_tokens" in call_kwargs
|
||||
assert call_kwargs["max_tokens"] is not None
|
||||
assert call_kwargs["max_tokens"] < 4096
|
||||
|
||||
|
||||
def test_anthropic_max_tokens_dynamic_for_large_context():
|
||||
"""
|
||||
Test that dynamic max_tokens is larger for models with large context windows
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
with patch.object(llm.client.messages, 'create') as mock_create:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
llm.call("Hello")
|
||||
|
||||
call_kwargs = mock_create.call_args[1]
|
||||
computed_max_tokens = call_kwargs["max_tokens"]
|
||||
|
||||
assert computed_max_tokens >= 1024
|
||||
assert computed_max_tokens <= 2048
|
||||
|
||||
|
||||
def test_anthropic_max_tokens_respects_user_value():
|
||||
"""
|
||||
Test that user-provided max_tokens is always respected
|
||||
"""
|
||||
test_values = [512, 1024, 2048, 4096, 8192]
|
||||
|
||||
for max_tokens_value in test_values:
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=max_tokens_value)
|
||||
|
||||
with patch.object(llm.client.messages, 'create') as mock_create:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
llm.call("Hello")
|
||||
|
||||
call_kwargs = mock_create.call_args[1]
|
||||
assert call_kwargs["max_tokens"] == max_tokens_value
|
||||
|
||||
Reference in New Issue
Block a user