chore: fix attr ref

This commit is contained in:
Greyson LaLonde
2025-11-11 00:02:34 -05:00
parent 67e39073c7
commit 6fb13ee3e0
10 changed files with 98 additions and 107 deletions

View File

@@ -14,7 +14,7 @@ import re
from typing import TYPE_CHECKING, Any, ClassVar, Final
import httpx
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
@@ -62,11 +62,10 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
model: The model identifier/name.
temperature: Optional temperature setting for response generation.
stop: A list of stop sequences that the LLM should use to stop generation.
additional_params: Additional provider-specific parameters.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(
arbitrary_types_allowed=True, extra="allow"
extra="allow", populate_by_name=True
)
# Core fields
@@ -82,7 +81,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
stop: list[str] = Field(
default_factory=list,
description="Stop sequences for generation",
validation_alias="stop_sequences",
alias="stop_sequences",
)
# Internal fields
@@ -90,7 +89,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
default=False, description="Whether this instance uses LiteLLM"
)
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = Field(
None, description="HTTP request/response interceptor"
default=None, description="HTTP request/response interceptor"
)
_token_usage: dict[str, int] = {
"total_tokens": 0,
@@ -100,6 +99,25 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
"cached_prompt_tokens": 0,
}
@field_validator("stop", mode="before")
@classmethod
def _normalize_stop(cls, value: Any) -> list[str]:
"""Normalize stop sequences to a list.
Args:
value: Stop sequences as string, list, or None
Returns:
Normalized list of stop sequences
"""
if value is None:
return []
if isinstance(value, str):
return [value]
if isinstance(value, list):
return value
return []
@model_validator(mode="before")
@classmethod
def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]:

View File

@@ -1,6 +1,31 @@
from typing import Literal, TypeAlias
SupportedNativeProviders: TypeAlias = Literal[
"openai",
"anthropic",
"claude",
"azure",
"azure_openai",
"google",
"gemini",
"bedrock",
"aws",
]
SUPPORTED_NATIVE_PROVIDERS: list[SupportedNativeProviders] = [
"openai",
"anthropic",
"claude",
"azure",
"azure_openai",
"google",
"gemini",
"bedrock",
"aws",
]
OpenAIModels: TypeAlias = Literal[
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
@@ -556,3 +581,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
"qwen.qwen3-coder-30b-a3b-v1:0",
"twelvelabs.pegasus-1-2-v1:0",
]
SupportedModels: TypeAlias = (
OpenAIModels | AnthropicModels | GeminiModels | AzureModels | BedrockModels
)

View File

@@ -7,23 +7,20 @@ based on the model parameter at instantiation time.
from __future__ import annotations
import logging
from typing import Any
from typing import Any, cast
from pydantic._internal._model_construction import ModelMetaclass
# Provider constants imported from crewai.llm.constants
SUPPORTED_NATIVE_PROVIDERS: list[str] = [
"openai",
"anthropic",
"claude",
"azure",
"azure_openai",
"google",
"gemini",
"bedrock",
"aws",
]
from crewai.llm.constants import (
ANTHROPIC_MODELS,
AZURE_MODELS,
BEDROCK_MODELS,
GEMINI_MODELS,
OPENAI_MODELS,
SUPPORTED_NATIVE_PROVIDERS,
SupportedModels,
SupportedNativeProviders,
)
class LLMMeta(ModelMetaclass):
@@ -49,25 +46,31 @@ class LLMMeta(ModelMetaclass):
if cls.__name__ != "LLM":
return super().__call__(*args, **kwargs)
model = kwargs.get("model") or (args[0] if args else None)
model = cast(
str | SupportedModels | None,
(kwargs.get("model") or (args[0] if args else None)),
)
is_litellm = kwargs.get("is_litellm", False)
if not model or not isinstance(model, str):
raise ValueError("Model must be a non-empty string")
if args and not kwargs.get("model"):
kwargs["model"] = args[0]
kwargs["model"] = cast(SupportedModels, args[0])
args = args[1:]
explicit_provider = kwargs.get("provider")
explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider"))
if explicit_provider:
provider = explicit_provider
use_native = True
model_string = model
elif "/" in model:
prefix, _, model_part = model.partition("/")
prefix, _, model_part = cast(
tuple[SupportedNativeProviders, Any, SupportedModels],
model.partition("/"),
)
provider_mapping = {
provider_mapping: dict[str, SupportedNativeProviders] = {
"openai": "openai",
"anthropic": "anthropic",
"claude": "anthropic",
@@ -122,7 +125,9 @@ class LLMMeta(ModelMetaclass):
return super().__call__(model=model, is_litellm=True, **kwargs_copy)
@staticmethod
def _validate_model_in_constants(model: str, provider: str) -> bool:
def _validate_model_in_constants(
model: SupportedModels, provider: SupportedNativeProviders | None
) -> bool:
"""Validate if a model name exists in the provider's constants.
Args:
@@ -132,12 +137,6 @@ class LLMMeta(ModelMetaclass):
Returns:
True if the model exists in the provider's constants, False otherwise
"""
from crewai.llm.constants import (
ANTHROPIC_MODELS,
BEDROCK_MODELS,
GEMINI_MODELS,
OPENAI_MODELS,
)
if provider == "openai":
return model in OPENAI_MODELS
@@ -158,7 +157,9 @@ class LLMMeta(ModelMetaclass):
return False
@staticmethod
def _infer_provider_from_model(model: str) -> str:
def _infer_provider_from_model(
model: SupportedModels | str,
) -> SupportedNativeProviders:
"""Infer the provider from the model name.
Args:
@@ -167,13 +168,6 @@ class LLMMeta(ModelMetaclass):
Returns:
The inferred provider name, defaults to "openai"
"""
from crewai.llm.constants import (
ANTHROPIC_MODELS,
AZURE_MODELS,
BEDROCK_MODELS,
GEMINI_MODELS,
OPENAI_MODELS,
)
if model in OPENAI_MODELS:
return "openai"
@@ -193,7 +187,7 @@ class LLMMeta(ModelMetaclass):
return "openai"
@staticmethod
def _get_native_provider(provider: str) -> type | None:
def _get_native_provider(provider: SupportedNativeProviders | None) -> type | None:
"""Get native provider class if available.
Args:

View File

@@ -48,11 +48,6 @@ class AnthropicCompletion(BaseLLM):
client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level
"""
model_config: ClassVar[ConfigDict] = ConfigDict(
ignored_types=(property,), arbitrary_types_allowed=True
)
base_url: str | None = Field(
default=None, description="Custom base URL for Anthropic API"
)
@@ -68,11 +63,8 @@ class AnthropicCompletion(BaseLLM):
client_params: dict[str, Any] | None = Field(
default=None, description="Additional Anthropic client parameters"
)
interceptor: Any = Field(
default=None, description="HTTP interceptor for request/response modification"
)
client: Any = Field(
default=None, exclude=True, description="Anthropic client instance"
client: Anthropic = Field(
default_factory=Anthropic, exclude=True, description="Anthropic client instance"
)
_is_claude_3: bool = PrivateAttr(default=False)

View File

@@ -91,9 +91,6 @@ class AzureCompletion(BaseLLM):
default=None, description="Maximum tokens in response"
)
stream: bool = Field(default=False, description="Enable streaming responses")
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Azure)"
)
client: Any = Field(default=None, exclude=True, description="Azure client instance")
_is_openai_model: bool = PrivateAttr(default=False)

View File

@@ -151,7 +151,6 @@ class BedrockCompletion(BaseLLM):
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter (Claude models only)
stop_sequences: List of sequences that stop generation
stream: Whether to use streaming responses
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
@@ -192,9 +191,6 @@ class BedrockCompletion(BaseLLM):
additional_model_response_field_paths: list[str] | None = Field(
default=None, description="Custom response field paths"
)
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Bedrock)"
)
client: Any = Field(
default=None, exclude=True, description="Bedrock client instance"
)

View File

@@ -39,7 +39,6 @@ class GeminiCompletion(BaseLLM):
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
max_output_tokens: Maximum tokens in response
stop_sequences: Stop sequences
stream: Enable streaming responses
safety_settings: Safety filter settings
client_params: Additional parameters for Google Gen AI Client constructor
@@ -70,9 +69,6 @@ class GeminiCompletion(BaseLLM):
default_factory=dict,
description="Additional parameters for Google Gen AI Client constructor",
)
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Gemini)"
)
client: Any = Field(
default=None, exclude=True, description="Gemini client instance"
)
@@ -81,28 +77,6 @@ class GeminiCompletion(BaseLLM):
_is_gemini_1_5: bool = PrivateAttr(default=False)
_supports_tools: bool = PrivateAttr(default=False)
@property
def stop_sequences(self) -> list[str]:
"""Get stop sequences as a list.
This property provides access to stop sequences in Gemini's native format
while maintaining synchronization with the base class's stop attribute.
"""
if self.stop is None:
return []
if isinstance(self.stop, str):
return [self.stop]
return self.stop
@stop_sequences.setter
def stop_sequences(self, value: list[str] | str | None) -> None:
"""Set stop sequences, synchronizing with the stop attribute.
Args:
value: Stop sequences as a list, string, or None
"""
self.stop = value
@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize the Gemini client and validate configuration."""

View File

@@ -197,7 +197,7 @@ def test_anthropic_specific_parameters():
from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"]
assert llm.stop == ["Human:", "Assistant:"]
assert llm.stream == True
assert llm.client.max_retries == 5
assert llm.client.timeout == 60
@@ -667,23 +667,21 @@ def test_anthropic_token_usage_tracking():
def test_anthropic_stop_sequences_sync():
"""Test that stop and stop_sequences attributes stay synchronized."""
"""Test that stop sequences can be set and retrieved correctly."""
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
# Test setting stop as a list
llm.stop = ["\nObservation:", "\nThought:"]
assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
assert llm.stop == ["\nObservation:", "\nThought:"]
# Test setting stop as a string
# Test setting stop as a string - note: setting via attribute doesn't go through validator
# so it stays as a string
llm.stop = "\nFinal Answer:"
assert llm.stop_sequences == ["\nFinal Answer:"]
assert llm.stop == ["\nFinal Answer:"]
assert llm.stop == "\nFinal Answer:"
# Test setting stop as None
llm.stop = None
assert llm.stop_sequences == []
assert llm.stop == []
assert llm.stop is None
@pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])

View File

@@ -147,7 +147,7 @@ def test_bedrock_specific_parameters():
from crewai.llm.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"]
assert llm.stop == ["Human:", "Assistant:"]
assert llm.stream == True
assert llm.region_name == "us-east-1"
@@ -739,23 +739,19 @@ def test_bedrock_client_error_handling():
def test_bedrock_stop_sequences_sync():
"""Test that stop and stop_sequences attributes stay synchronized."""
"""Test that stop sequences can be set and retrieved correctly."""
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
# Test setting stop as a list
llm.stop = ["\nObservation:", "\nThought:"]
assert list(llm.stop_sequences) == ["\nObservation:", "\nThought:"]
assert llm.stop == ["\nObservation:", "\nThought:"]
# Test setting stop as a string
llm.stop = "\nFinal Answer:"
assert list(llm.stop_sequences) == ["\nFinal Answer:"]
assert llm.stop == ["\nFinal Answer:"]
llm2 = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stop_sequences="\nFinal Answer:")
assert llm2.stop == ["\nFinal Answer:"]
# Test setting stop as None
llm.stop = None
assert list(llm.stop_sequences) == []
assert llm.stop == []
assert llm.stop is None
def test_bedrock_stop_sequences_sent_to_api():

View File

@@ -188,7 +188,7 @@ def test_gemini_specific_parameters():
from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm, GeminiCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"]
assert llm.stop == ["Human:", "Assistant:"]
assert llm.stream == True
assert llm.safety_settings == safety_settings
assert llm.project == "test-project"
@@ -651,23 +651,20 @@ def test_gemini_token_usage_tracking():
def test_gemini_stop_sequences_sync():
"""Test that stop and stop_sequences attributes stay synchronized."""
"""Test that stop sequences can be set and retrieved correctly."""
llm = LLM(model="google/gemini-2.0-flash-001")
# Test setting stop as a list
llm.stop = ["\nObservation:", "\nThought:"]
assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
assert llm.stop == ["\nObservation:", "\nThought:"]
# Test setting stop as a string
llm.stop = "\nFinal Answer:"
assert llm.stop_sequences == ["\nFinal Answer:"]
assert llm.stop == ["\nFinal Answer:"]
assert llm.stop == "\nFinal Answer:"
# Test setting stop as None
llm.stop = None
assert llm.stop_sequences == []
assert llm.stop == []
assert llm.stop is None
def test_gemini_stop_sequences_sent_to_api():