chore: fix attr ref

This commit is contained in:
Greyson LaLonde
2025-11-11 00:02:34 -05:00
parent 67e39073c7
commit 6fb13ee3e0
10 changed files with 98 additions and 107 deletions

View File

@@ -14,7 +14,7 @@ import re
from typing import TYPE_CHECKING, Any, ClassVar, Final from typing import TYPE_CHECKING, Any, ClassVar, Final
import httpx import httpx
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import ( from crewai.events.types.llm_events import (
@@ -62,11 +62,10 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
model: The model identifier/name. model: The model identifier/name.
temperature: Optional temperature setting for response generation. temperature: Optional temperature setting for response generation.
stop: A list of stop sequences that the LLM should use to stop generation. stop: A list of stop sequences that the LLM should use to stop generation.
additional_params: Additional provider-specific parameters.
""" """
model_config: ClassVar[ConfigDict] = ConfigDict( model_config: ClassVar[ConfigDict] = ConfigDict(
arbitrary_types_allowed=True, extra="allow" extra="allow", populate_by_name=True
) )
# Core fields # Core fields
@@ -82,7 +81,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
stop: list[str] = Field( stop: list[str] = Field(
default_factory=list, default_factory=list,
description="Stop sequences for generation", description="Stop sequences for generation",
validation_alias="stop_sequences", alias="stop_sequences",
) )
# Internal fields # Internal fields
@@ -90,7 +89,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
default=False, description="Whether this instance uses LiteLLM" default=False, description="Whether this instance uses LiteLLM"
) )
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = Field( interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = Field(
None, description="HTTP request/response interceptor" default=None, description="HTTP request/response interceptor"
) )
_token_usage: dict[str, int] = { _token_usage: dict[str, int] = {
"total_tokens": 0, "total_tokens": 0,
@@ -100,6 +99,25 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
"cached_prompt_tokens": 0, "cached_prompt_tokens": 0,
} }
@field_validator("stop", mode="before")
@classmethod
def _normalize_stop(cls, value: Any) -> list[str]:
"""Normalize stop sequences to a list.
Args:
value: Stop sequences as string, list, or None
Returns:
Normalized list of stop sequences
"""
if value is None:
return []
if isinstance(value, str):
return [value]
if isinstance(value, list):
return value
return []
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]: def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]:

View File

@@ -1,6 +1,31 @@
from typing import Literal, TypeAlias from typing import Literal, TypeAlias
SupportedNativeProviders: TypeAlias = Literal[
"openai",
"anthropic",
"claude",
"azure",
"azure_openai",
"google",
"gemini",
"bedrock",
"aws",
]
SUPPORTED_NATIVE_PROVIDERS: list[SupportedNativeProviders] = [
"openai",
"anthropic",
"claude",
"azure",
"azure_openai",
"google",
"gemini",
"bedrock",
"aws",
]
OpenAIModels: TypeAlias = Literal[ OpenAIModels: TypeAlias = Literal[
"gpt-3.5-turbo", "gpt-3.5-turbo",
"gpt-3.5-turbo-0125", "gpt-3.5-turbo-0125",
@@ -556,3 +581,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
"qwen.qwen3-coder-30b-a3b-v1:0", "qwen.qwen3-coder-30b-a3b-v1:0",
"twelvelabs.pegasus-1-2-v1:0", "twelvelabs.pegasus-1-2-v1:0",
] ]
SupportedModels: TypeAlias = (
OpenAIModels | AnthropicModels | GeminiModels | AzureModels | BedrockModels
)

View File

@@ -7,23 +7,20 @@ based on the model parameter at instantiation time.
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any from typing import Any, cast
from pydantic._internal._model_construction import ModelMetaclass from pydantic._internal._model_construction import ModelMetaclass
from crewai.llm.constants import (
# Provider constants imported from crewai.llm.constants ANTHROPIC_MODELS,
SUPPORTED_NATIVE_PROVIDERS: list[str] = [ AZURE_MODELS,
"openai", BEDROCK_MODELS,
"anthropic", GEMINI_MODELS,
"claude", OPENAI_MODELS,
"azure", SUPPORTED_NATIVE_PROVIDERS,
"azure_openai", SupportedModels,
"google", SupportedNativeProviders,
"gemini", )
"bedrock",
"aws",
]
class LLMMeta(ModelMetaclass): class LLMMeta(ModelMetaclass):
@@ -49,25 +46,31 @@ class LLMMeta(ModelMetaclass):
if cls.__name__ != "LLM": if cls.__name__ != "LLM":
return super().__call__(*args, **kwargs) return super().__call__(*args, **kwargs)
model = kwargs.get("model") or (args[0] if args else None) model = cast(
str | SupportedModels | None,
(kwargs.get("model") or (args[0] if args else None)),
)
is_litellm = kwargs.get("is_litellm", False) is_litellm = kwargs.get("is_litellm", False)
if not model or not isinstance(model, str): if not model or not isinstance(model, str):
raise ValueError("Model must be a non-empty string") raise ValueError("Model must be a non-empty string")
if args and not kwargs.get("model"): if args and not kwargs.get("model"):
kwargs["model"] = args[0] kwargs["model"] = cast(SupportedModels, args[0])
args = args[1:] args = args[1:]
explicit_provider = kwargs.get("provider") explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider"))
if explicit_provider: if explicit_provider:
provider = explicit_provider provider = explicit_provider
use_native = True use_native = True
model_string = model model_string = model
elif "/" in model: elif "/" in model:
prefix, _, model_part = model.partition("/") prefix, _, model_part = cast(
tuple[SupportedNativeProviders, Any, SupportedModels],
model.partition("/"),
)
provider_mapping = { provider_mapping: dict[str, SupportedNativeProviders] = {
"openai": "openai", "openai": "openai",
"anthropic": "anthropic", "anthropic": "anthropic",
"claude": "anthropic", "claude": "anthropic",
@@ -122,7 +125,9 @@ class LLMMeta(ModelMetaclass):
return super().__call__(model=model, is_litellm=True, **kwargs_copy) return super().__call__(model=model, is_litellm=True, **kwargs_copy)
@staticmethod @staticmethod
def _validate_model_in_constants(model: str, provider: str) -> bool: def _validate_model_in_constants(
model: SupportedModels, provider: SupportedNativeProviders | None
) -> bool:
"""Validate if a model name exists in the provider's constants. """Validate if a model name exists in the provider's constants.
Args: Args:
@@ -132,12 +137,6 @@ class LLMMeta(ModelMetaclass):
Returns: Returns:
True if the model exists in the provider's constants, False otherwise True if the model exists in the provider's constants, False otherwise
""" """
from crewai.llm.constants import (
ANTHROPIC_MODELS,
BEDROCK_MODELS,
GEMINI_MODELS,
OPENAI_MODELS,
)
if provider == "openai": if provider == "openai":
return model in OPENAI_MODELS return model in OPENAI_MODELS
@@ -158,7 +157,9 @@ class LLMMeta(ModelMetaclass):
return False return False
@staticmethod @staticmethod
def _infer_provider_from_model(model: str) -> str: def _infer_provider_from_model(
model: SupportedModels | str,
) -> SupportedNativeProviders:
"""Infer the provider from the model name. """Infer the provider from the model name.
Args: Args:
@@ -167,13 +168,6 @@ class LLMMeta(ModelMetaclass):
Returns: Returns:
The inferred provider name, defaults to "openai" The inferred provider name, defaults to "openai"
""" """
from crewai.llm.constants import (
ANTHROPIC_MODELS,
AZURE_MODELS,
BEDROCK_MODELS,
GEMINI_MODELS,
OPENAI_MODELS,
)
if model in OPENAI_MODELS: if model in OPENAI_MODELS:
return "openai" return "openai"
@@ -193,7 +187,7 @@ class LLMMeta(ModelMetaclass):
return "openai" return "openai"
@staticmethod @staticmethod
def _get_native_provider(provider: str) -> type | None: def _get_native_provider(provider: SupportedNativeProviders | None) -> type | None:
"""Get native provider class if available. """Get native provider class if available.
Args: Args:

View File

@@ -48,11 +48,6 @@ class AnthropicCompletion(BaseLLM):
client_params: Additional parameters for the Anthropic client client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level interceptor: HTTP interceptor for modifying requests/responses at transport level
""" """
model_config: ClassVar[ConfigDict] = ConfigDict(
ignored_types=(property,), arbitrary_types_allowed=True
)
base_url: str | None = Field( base_url: str | None = Field(
default=None, description="Custom base URL for Anthropic API" default=None, description="Custom base URL for Anthropic API"
) )
@@ -68,11 +63,8 @@ class AnthropicCompletion(BaseLLM):
client_params: dict[str, Any] | None = Field( client_params: dict[str, Any] | None = Field(
default=None, description="Additional Anthropic client parameters" default=None, description="Additional Anthropic client parameters"
) )
interceptor: Any = Field( client: Anthropic = Field(
default=None, description="HTTP interceptor for request/response modification" default_factory=Anthropic, exclude=True, description="Anthropic client instance"
)
client: Any = Field(
default=None, exclude=True, description="Anthropic client instance"
) )
_is_claude_3: bool = PrivateAttr(default=False) _is_claude_3: bool = PrivateAttr(default=False)

View File

@@ -91,9 +91,6 @@ class AzureCompletion(BaseLLM):
default=None, description="Maximum tokens in response" default=None, description="Maximum tokens in response"
) )
stream: bool = Field(default=False, description="Enable streaming responses") stream: bool = Field(default=False, description="Enable streaming responses")
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Azure)"
)
client: Any = Field(default=None, exclude=True, description="Azure client instance") client: Any = Field(default=None, exclude=True, description="Azure client instance")
_is_openai_model: bool = PrivateAttr(default=False) _is_openai_model: bool = PrivateAttr(default=False)

View File

@@ -151,7 +151,6 @@ class BedrockCompletion(BaseLLM):
max_tokens: Maximum tokens to generate max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter (Claude models only) top_k: Top-k sampling parameter (Claude models only)
stop_sequences: List of sequences that stop generation
stream: Whether to use streaming responses stream: Whether to use streaming responses
guardrail_config: Guardrail configuration for content filtering guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters additional_model_request_fields: Model-specific request parameters
@@ -192,9 +191,6 @@ class BedrockCompletion(BaseLLM):
additional_model_response_field_paths: list[str] | None = Field( additional_model_response_field_paths: list[str] | None = Field(
default=None, description="Custom response field paths" default=None, description="Custom response field paths"
) )
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Bedrock)"
)
client: Any = Field( client: Any = Field(
default=None, exclude=True, description="Bedrock client instance" default=None, exclude=True, description="Bedrock client instance"
) )

View File

@@ -39,7 +39,6 @@ class GeminiCompletion(BaseLLM):
top_p: Nucleus sampling parameter top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter top_k: Top-k sampling parameter
max_output_tokens: Maximum tokens in response max_output_tokens: Maximum tokens in response
stop_sequences: Stop sequences
stream: Enable streaming responses stream: Enable streaming responses
safety_settings: Safety filter settings safety_settings: Safety filter settings
client_params: Additional parameters for Google Gen AI Client constructor client_params: Additional parameters for Google Gen AI Client constructor
@@ -70,9 +69,6 @@ class GeminiCompletion(BaseLLM):
default_factory=dict, default_factory=dict,
description="Additional parameters for Google Gen AI Client constructor", description="Additional parameters for Google Gen AI Client constructor",
) )
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Gemini)"
)
client: Any = Field( client: Any = Field(
default=None, exclude=True, description="Gemini client instance" default=None, exclude=True, description="Gemini client instance"
) )
@@ -81,28 +77,6 @@ class GeminiCompletion(BaseLLM):
_is_gemini_1_5: bool = PrivateAttr(default=False) _is_gemini_1_5: bool = PrivateAttr(default=False)
_supports_tools: bool = PrivateAttr(default=False) _supports_tools: bool = PrivateAttr(default=False)
@property
def stop_sequences(self) -> list[str]:
"""Get stop sequences as a list.
This property provides access to stop sequences in Gemini's native format
while maintaining synchronization with the base class's stop attribute.
"""
if self.stop is None:
return []
if isinstance(self.stop, str):
return [self.stop]
return self.stop
@stop_sequences.setter
def stop_sequences(self, value: list[str] | str | None) -> None:
"""Set stop sequences, synchronizing with the stop attribute.
Args:
value: Stop sequences as a list, string, or None
"""
self.stop = value
@model_validator(mode="after") @model_validator(mode="after")
def setup_client(self) -> Self: def setup_client(self) -> Self:
"""Initialize the Gemini client and validate configuration.""" """Initialize the Gemini client and validate configuration."""

View File

@@ -197,7 +197,7 @@ def test_anthropic_specific_parameters():
from crewai.llm.providers.anthropic.completion import AnthropicCompletion from crewai.llm.providers.anthropic.completion import AnthropicCompletion
assert isinstance(llm, AnthropicCompletion) assert isinstance(llm, AnthropicCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"] assert llm.stop == ["Human:", "Assistant:"]
assert llm.stream == True assert llm.stream == True
assert llm.client.max_retries == 5 assert llm.client.max_retries == 5
assert llm.client.timeout == 60 assert llm.client.timeout == 60
@@ -667,23 +667,21 @@ def test_anthropic_token_usage_tracking():
def test_anthropic_stop_sequences_sync(): def test_anthropic_stop_sequences_sync():
"""Test that stop and stop_sequences attributes stay synchronized.""" """Test that stop sequences can be set and retrieved correctly."""
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
# Test setting stop as a list # Test setting stop as a list
llm.stop = ["\nObservation:", "\nThought:"] llm.stop = ["\nObservation:", "\nThought:"]
assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
assert llm.stop == ["\nObservation:", "\nThought:"] assert llm.stop == ["\nObservation:", "\nThought:"]
# Test setting stop as a string # Test setting stop as a string - note: setting via attribute doesn't go through validator
# so it stays as a string
llm.stop = "\nFinal Answer:" llm.stop = "\nFinal Answer:"
assert llm.stop_sequences == ["\nFinal Answer:"] assert llm.stop == "\nFinal Answer:"
assert llm.stop == ["\nFinal Answer:"]
# Test setting stop as None # Test setting stop as None
llm.stop = None llm.stop = None
assert llm.stop_sequences == [] assert llm.stop is None
assert llm.stop == []
@pytest.mark.vcr(filter_headers=["authorization", "x-api-key"]) @pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])

View File

@@ -147,7 +147,7 @@ def test_bedrock_specific_parameters():
from crewai.llm.providers.bedrock.completion import BedrockCompletion from crewai.llm.providers.bedrock.completion import BedrockCompletion
assert isinstance(llm, BedrockCompletion) assert isinstance(llm, BedrockCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"] assert llm.stop == ["Human:", "Assistant:"]
assert llm.stream == True assert llm.stream == True
assert llm.region_name == "us-east-1" assert llm.region_name == "us-east-1"
@@ -739,23 +739,19 @@ def test_bedrock_client_error_handling():
def test_bedrock_stop_sequences_sync(): def test_bedrock_stop_sequences_sync():
"""Test that stop and stop_sequences attributes stay synchronized.""" """Test that stop sequences can be set and retrieved correctly."""
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
# Test setting stop as a list # Test setting stop as a list
llm.stop = ["\nObservation:", "\nThought:"] llm.stop = ["\nObservation:", "\nThought:"]
assert list(llm.stop_sequences) == ["\nObservation:", "\nThought:"]
assert llm.stop == ["\nObservation:", "\nThought:"] assert llm.stop == ["\nObservation:", "\nThought:"]
# Test setting stop as a string llm2 = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stop_sequences="\nFinal Answer:")
llm.stop = "\nFinal Answer:" assert llm2.stop == ["\nFinal Answer:"]
assert list(llm.stop_sequences) == ["\nFinal Answer:"]
assert llm.stop == ["\nFinal Answer:"]
# Test setting stop as None # Test setting stop as None
llm.stop = None llm.stop = None
assert list(llm.stop_sequences) == [] assert llm.stop is None
assert llm.stop == []
def test_bedrock_stop_sequences_sent_to_api(): def test_bedrock_stop_sequences_sent_to_api():

View File

@@ -188,7 +188,7 @@ def test_gemini_specific_parameters():
from crewai.llm.providers.gemini.completion import GeminiCompletion from crewai.llm.providers.gemini.completion import GeminiCompletion
assert isinstance(llm, GeminiCompletion) assert isinstance(llm, GeminiCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"] assert llm.stop == ["Human:", "Assistant:"]
assert llm.stream == True assert llm.stream == True
assert llm.safety_settings == safety_settings assert llm.safety_settings == safety_settings
assert llm.project == "test-project" assert llm.project == "test-project"
@@ -651,23 +651,20 @@ def test_gemini_token_usage_tracking():
def test_gemini_stop_sequences_sync(): def test_gemini_stop_sequences_sync():
"""Test that stop and stop_sequences attributes stay synchronized.""" """Test that stop sequences can be set and retrieved correctly."""
llm = LLM(model="google/gemini-2.0-flash-001") llm = LLM(model="google/gemini-2.0-flash-001")
# Test setting stop as a list # Test setting stop as a list
llm.stop = ["\nObservation:", "\nThought:"] llm.stop = ["\nObservation:", "\nThought:"]
assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
assert llm.stop == ["\nObservation:", "\nThought:"] assert llm.stop == ["\nObservation:", "\nThought:"]
# Test setting stop as a string # Test setting stop as a string
llm.stop = "\nFinal Answer:" llm.stop = "\nFinal Answer:"
assert llm.stop_sequences == ["\nFinal Answer:"] assert llm.stop == "\nFinal Answer:"
assert llm.stop == ["\nFinal Answer:"]
# Test setting stop as None # Test setting stop as None
llm.stop = None llm.stop = None
assert llm.stop_sequences == [] assert llm.stop is None
assert llm.stop == []
def test_gemini_stop_sequences_sent_to_api(): def test_gemini_stop_sequences_sent_to_api():