Fix Anthropic max_tokens issue causing slow execution (issue #3807)
Make max_tokens optional and compute it dynamically when not set by the user.

Previously, max_tokens defaulted to 4096 and was always passed to the Anthropic API, causing the model to generate up to 4096 tokens even for simple queries that should only need a few tokens. This resulted in extremely slow execution times.

Changes:
- Changed the max_tokens parameter from int (default 4096) to int | None (default None)
- Added dynamic computation in _prepare_completion_params():
  - Default: 1024 tokens (much more reasonable for most queries)
  - Large-context models (200k+): 2048 tokens
  - User-specified values are always respected
- Updated the docstring to reflect that max_tokens is now optional
- Added comprehensive tests covering:
  - Explicit max_tokens values are passed through unchanged
  - Default behavior computes a reasonable max_tokens dynamically
  - max_tokens=None uses dynamic computation
  - Dynamic values are appropriate for the model's context window size
  - User-provided values are always respected

This fix aligns with the v0.203.1 behavior, where max_tokens was optional and only passed when explicitly set, while maintaining compatibility with the Anthropic SDK requirement that max_tokens must be provided.

Fixes #3807

Co-Authored-By: João <joao@crewai.com>
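In essence, a value is only computed when the caller leaves max_tokens unset. A minimal standalone sketch of that fallback logic (resolve_max_tokens is a hypothetical helper written for illustration, not part of the patch; the real code lives inline in _prepare_completion_params as shown in the diff below):

# Sketch only: mirrors the fallback described above, assuming the context window
# size is available (the patch obtains it via self.get_context_window_size()).
def resolve_max_tokens(user_max_tokens: int | None, context_window: int) -> int:
    """Return the max_tokens value to send to the Anthropic API."""
    if user_max_tokens is not None:
        # A user-specified value is always respected.
        return user_max_tokens
    # Conservative default that keeps simple queries fast.
    max_tokens = 1024
    if context_window > 100_000:  # Claude models with 200k+ context windows
        max_tokens = 2048
    return max_tokens


# Example: a large-context model with no user override gets 2048,
# while an explicit value passes through unchanged.
assert resolve_max_tokens(None, 200_000) == 2048
assert resolve_max_tokens(4096, 200_000) == 4096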
@@ -36,7 +36,7 @@ class AnthropicCompletion(BaseLLM):
         timeout: float | None = None,
         max_retries: int = 2,
         temperature: float | None = None,
-        max_tokens: int = 4096,  # Required for Anthropic
+        max_tokens: int | None = None,  # Optional, computed dynamically if not set
         top_p: float | None = None,
         stop_sequences: list[str] | None = None,
         stream: bool = False,
@@ -52,7 +52,8 @@ class AnthropicCompletion(BaseLLM):
             timeout: Request timeout in seconds
             max_retries: Maximum number of retries
             temperature: Sampling temperature (0-1)
-            max_tokens: Maximum tokens in response (required for Anthropic)
+            max_tokens: Maximum tokens in response. If not set, will be computed
+                dynamically based on context window size (recommended for most use cases)
             top_p: Nucleus sampling parameter
             stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
             stream: Enable streaming responses
@@ -72,7 +73,7 @@ class AnthropicCompletion(BaseLLM):
         self.client = Anthropic(**self._get_client_params())

         # Store completion parameters
-        self.max_tokens = max_tokens
+        self.max_tokens = max_tokens  # Can be None, will be computed dynamically
         self.top_p = top_p
         self.stream = stream
         self.stop_sequences = stop_sequences or []
@@ -178,10 +179,19 @@ class AnthropicCompletion(BaseLLM):
         Returns:
             Parameters dictionary for Anthropic API
         """
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            # while still allowing enough tokens for most responses
+            max_tokens = 1024
+
+            context_window = self.get_context_window_size()
+            if context_window > 100000:  # For Claude models with 200k+ context
+                max_tokens = 2048
+
         params = {
             "model": self.model,
             "messages": messages,
-            "max_tokens": self.max_tokens,
+            "max_tokens": max_tokens,
             "stream": self.stream,
         }

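The net effect for callers is sketched below; it assumes crewAI's LLM wrapper routes "anthropic/..." models to AnthropicCompletion, as the new tests in the next hunk do, and the import path is an assumption made for illustration rather than part of this diff.

# Hypothetical usage sketch, not part of the commit.
from crewai import LLM

# Explicit cap: passed through to the Anthropic API unchanged.
llm_capped = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=4096)

# No cap given: a modest value (1024, or 2048 for large-context models) is
# chosen at request time instead of always sending 4096.
llm_default = LLM(model="anthropic/claude-3-5-sonnet-20241022")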
@@ -664,3 +664,105 @@ def test_anthropic_token_usage_tracking():
     assert usage["input_tokens"] == 50
     assert usage["output_tokens"] == 25
     assert usage["total_tokens"] == 75
+
+
+def test_anthropic_max_tokens_explicit():
+    """
+    Test that explicit max_tokens is passed through to the API
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=4096)
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        assert call_kwargs["max_tokens"] == 4096
+
+
+def test_anthropic_max_tokens_default_computed():
+    """
+    Test that max_tokens is computed dynamically when not explicitly set
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        assert "max_tokens" in call_kwargs
+        assert call_kwargs["max_tokens"] is not None
+        assert call_kwargs["max_tokens"] < 4096
+
+
+def test_anthropic_max_tokens_none_uses_dynamic_default():
+    """
+    Test that max_tokens=None results in dynamic computation
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=None)
+
+    assert llm.max_tokens is None
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        assert "max_tokens" in call_kwargs
+        assert call_kwargs["max_tokens"] is not None
+        assert call_kwargs["max_tokens"] < 4096
+
+
+def test_anthropic_max_tokens_dynamic_for_large_context():
+    """
+    Test that dynamic max_tokens is larger for models with large context windows
+    """
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
+
+    with patch.object(llm.client.messages, 'create') as mock_create:
+        mock_response = MagicMock()
+        mock_response.content = [MagicMock(text="test response")]
+        mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+        mock_create.return_value = mock_response
+
+        llm.call("Hello")
+
+        call_kwargs = mock_create.call_args[1]
+        computed_max_tokens = call_kwargs["max_tokens"]
+
+        assert computed_max_tokens >= 1024
+        assert computed_max_tokens <= 2048
+
+
+def test_anthropic_max_tokens_respects_user_value():
+    """
+    Test that user-provided max_tokens is always respected
+    """
+    test_values = [512, 1024, 2048, 4096, 8192]
+
+    for max_tokens_value in test_values:
+        llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", max_tokens=max_tokens_value)
+
+        with patch.object(llm.client.messages, 'create') as mock_create:
+            mock_response = MagicMock()
+            mock_response.content = [MagicMock(text="test response")]
+            mock_response.usage = MagicMock(input_tokens=10, output_tokens=20)
+            mock_create.return_value = mock_response
+
+            llm.call("Hello")
+
+            call_kwargs = mock_create.call_args[1]
+            assert call_kwargs["max_tokens"] == max_tokens_value