mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-29 18:18:13 +00:00
fix: Improve model name validation and fix syntax errors
- Fix docstring placement and type hints - Add proper model name validation with clear error messages - Organize tests into a class and add edge cases Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -117,18 +117,42 @@ def suppress_warnings():
|
|||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM:
|
||||||
|
"""LLM class for handling model interactions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model identifier; should not start with 'models/'.
|
||||||
|
Examples: 'gemini/gemini-1.5-pro', 'anthropic/claude-3'
|
||||||
|
timeout: Optional timeout for model calls
|
||||||
|
temperature: Optional temperature parameter
|
||||||
|
max_tokens: Optional maximum tokens for completion
|
||||||
|
max_completion_tokens: Optional maximum completion tokens
|
||||||
|
logprobs: Optional log probabilities
|
||||||
|
top_p: Optional nucleus sampling parameter
|
||||||
|
n: Optional number of completions
|
||||||
|
stop: Optional stop sequences
|
||||||
|
presence_penalty: Optional presence penalty
|
||||||
|
frequency_penalty: Optional frequency penalty
|
||||||
|
logit_bias: Optional token biasing
|
||||||
|
user: Optional user identifier
|
||||||
|
response_format: Optional response format configuration
|
||||||
|
seed: Optional random seed
|
||||||
|
tools: Optional list of tools
|
||||||
|
tool_choice: Optional tool choice configuration
|
||||||
|
api_base: Optional API base URL
|
||||||
|
api_key: Optional API key
|
||||||
|
api_version: Optional API version
|
||||||
|
base_url: Optional base URL
|
||||||
|
top_logprobs: Optional top log probabilities
|
||||||
|
callbacks: Optional list of callbacks
|
||||||
|
reasoning_effort: Optional reasoning effort level
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the model name starts with 'models/' or is empty
|
||||||
|
TypeError: If model is not a string
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Union[str, 'BaseLanguageModel'],
|
model: str,
|
||||||
"""Initialize LLM instance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model identifier; should not start with 'models/'.
|
|
||||||
Examples: 'gemini/gemini-1.5-pro', 'anthropic/claude-3'
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the model name starts with 'models/'.
|
|
||||||
"""
|
|
||||||
timeout: Optional[Union[float, int]] = None,
|
timeout: Optional[Union[float, int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
@@ -155,7 +179,11 @@ class LLM:
|
|||||||
INVALID_MODEL_PREFIX = "models/"
|
INVALID_MODEL_PREFIX = "models/"
|
||||||
|
|
||||||
# Validate model name
|
# Validate model name
|
||||||
if isinstance(model, str) and model.startswith(INVALID_MODEL_PREFIX):
|
if not isinstance(model, str):
|
||||||
|
raise TypeError("Model name must be a string")
|
||||||
|
if not model:
|
||||||
|
raise ValueError("Model name cannot be empty")
|
||||||
|
if model.startswith(INVALID_MODEL_PREFIX):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'Invalid model name "{model}": Model names should not start with "{INVALID_MODEL_PREFIX}". '
|
f'Invalid model name "{model}": Model names should not start with "{INVALID_MODEL_PREFIX}". '
|
||||||
'Use the provider prefix instead (e.g., "gemini/model-name").'
|
'Use the provider prefix instead (e.g., "gemini/model-name").'
|
||||||
|
|||||||
@@ -269,9 +269,9 @@ class TestModelNameValidation:
|
|||||||
|
|
||||||
def test_edge_cases(self):
|
def test_edge_cases(self):
|
||||||
"""Test edge cases for model name validation."""
|
"""Test edge cases for model name validation."""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
LLM(model="") # Empty string
|
LLM(model="") # Empty string
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError, match="must be a string"):
|
||||||
LLM(model=None) # None value
|
LLM(model=None) # None value
|
||||||
|
|
||||||
|
|
||||||
@@ -347,13 +347,16 @@ def test_anthropic_model_detection():
|
|||||||
("claude-instant", True),
|
("claude-instant", True),
|
||||||
("claude/v1", True),
|
("claude/v1", True),
|
||||||
("gpt-4", False),
|
("gpt-4", False),
|
||||||
("", False),
|
|
||||||
("anthropomorphic", False), # Should not match partial words
|
("anthropomorphic", False), # Should not match partial words
|
||||||
]
|
]
|
||||||
|
|
||||||
for model, expected in models:
|
for model, expected in models:
|
||||||
llm = LLM(model=model)
|
llm = LLM(model=model)
|
||||||
assert llm.is_anthropic == expected, f"Failed for model: {model}"
|
assert llm._is_anthropic_model(model) == expected, f"Failed for model: {model}"
|
||||||
|
|
||||||
|
# Test empty model name separately since it raises ValueError
|
||||||
|
with pytest.raises(ValueError, match="cannot be empty"):
|
||||||
|
LLM(model="")
|
||||||
|
|
||||||
def test_anthropic_message_formatting(anthropic_llm, system_message, user_message):
|
def test_anthropic_message_formatting(anthropic_llm, system_message, user_message):
|
||||||
"""Test Anthropic message formatting with fixtures."""
|
"""Test Anthropic message formatting with fixtures."""
|
||||||
|
|||||||
Reference in New Issue
Block a user