diff --git a/lib/crewai/src/crewai/llm/base_llm.py b/lib/crewai/src/crewai/llm/base_llm.py
index f60ce500e..e86fd817b 100644
--- a/lib/crewai/src/crewai/llm/base_llm.py
+++ b/lib/crewai/src/crewai/llm/base_llm.py
@@ -14,7 +14,7 @@ import re
 from typing import TYPE_CHECKING, Any, ClassVar, Final
 
 import httpx
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.llm_events import (
@@ -62,11 +62,10 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         model: The model identifier/name.
         temperature: Optional temperature setting for response generation.
         stop: A list of stop sequences that the LLM should use to stop generation.
-        additional_params: Additional provider-specific parameters.
     """
 
     model_config: ClassVar[ConfigDict] = ConfigDict(
-        arbitrary_types_allowed=True, extra="allow"
+        extra="allow", populate_by_name=True
     )
 
     # Core fields
@@ -82,7 +81,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
     stop: list[str] = Field(
         default_factory=list,
         description="Stop sequences for generation",
-        validation_alias="stop_sequences",
+        alias="stop_sequences",
     )
 
     # Internal fields
@@ -90,7 +89,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         default=False, description="Whether this instance uses LiteLLM"
     )
     interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = Field(
-        None, description="HTTP request/response interceptor"
+        default=None, description="HTTP request/response interceptor"
     )
     _token_usage: dict[str, int] = {
         "total_tokens": 0,
@@ -100,6 +99,25 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         "cached_prompt_tokens": 0,
     }
 
+    @field_validator("stop", mode="before")
+    @classmethod
+    def _normalize_stop(cls, value: Any) -> list[str]:
+        """Normalize stop sequences to a list.
+
+        Args:
+            value: Stop sequences as string, list, or None
+
+        Returns:
+            Normalized list of stop sequences
+        """
+        if value is None:
+            return []
+        if isinstance(value, str):
+            return [value]
+        if isinstance(value, list):
+            return value
+        return []
+
     @model_validator(mode="before")
     @classmethod
     def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]:
diff --git a/lib/crewai/src/crewai/llm/constants.py b/lib/crewai/src/crewai/llm/constants.py
index 2765a9458..880e14e9c 100644
--- a/lib/crewai/src/crewai/llm/constants.py
+++ b/lib/crewai/src/crewai/llm/constants.py
@@ -1,6 +1,31 @@
 from typing import Literal, TypeAlias
 
 
+SupportedNativeProviders: TypeAlias = Literal[
+    "openai",
+    "anthropic",
+    "claude",
+    "azure",
+    "azure_openai",
+    "google",
+    "gemini",
+    "bedrock",
+    "aws",
+]
+
+SUPPORTED_NATIVE_PROVIDERS: list[SupportedNativeProviders] = [
+    "openai",
+    "anthropic",
+    "claude",
+    "azure",
+    "azure_openai",
+    "google",
+    "gemini",
+    "bedrock",
+    "aws",
+]
+
+
 OpenAIModels: TypeAlias = Literal[
     "gpt-3.5-turbo",
     "gpt-3.5-turbo-0125",
@@ -556,3 +581,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
     "qwen.qwen3-coder-30b-a3b-v1:0",
     "twelvelabs.pegasus-1-2-v1:0",
 ]
+
+SupportedModels: TypeAlias = (
+    OpenAIModels | AnthropicModels | GeminiModels | AzureModels | BedrockModels
+)
diff --git a/lib/crewai/src/crewai/llm/internal/meta.py b/lib/crewai/src/crewai/llm/internal/meta.py
index f91fad96d..f210ab742 100644
--- a/lib/crewai/src/crewai/llm/internal/meta.py
+++ b/lib/crewai/src/crewai/llm/internal/meta.py
@@ -7,23 +7,20 @@ based on the model parameter at instantiation time.
 from __future__ import annotations
 
 import logging
-from typing import Any
+from typing import Any, cast
 
 from pydantic._internal._model_construction import ModelMetaclass
-
-# Provider constants imported from crewai.llm.constants
-SUPPORTED_NATIVE_PROVIDERS: list[str] = [
-    "openai",
-    "anthropic",
-    "claude",
-    "azure",
-    "azure_openai",
-    "google",
-    "gemini",
-    "bedrock",
-    "aws",
-]
+from crewai.llm.constants import (
+    ANTHROPIC_MODELS,
+    AZURE_MODELS,
+    BEDROCK_MODELS,
+    GEMINI_MODELS,
+    OPENAI_MODELS,
+    SUPPORTED_NATIVE_PROVIDERS,
+    SupportedModels,
+    SupportedNativeProviders,
+)
 
 
 class LLMMeta(ModelMetaclass):
@@ -49,25 +46,31 @@
         if cls.__name__ != "LLM":
             return super().__call__(*args, **kwargs)
 
-        model = kwargs.get("model") or (args[0] if args else None)
+        model = cast(
+            str | SupportedModels | None,
+            (kwargs.get("model") or (args[0] if args else None)),
+        )
         is_litellm = kwargs.get("is_litellm", False)
 
         if not model or not isinstance(model, str):
             raise ValueError("Model must be a non-empty string")
 
         if args and not kwargs.get("model"):
-            kwargs["model"] = args[0]
+            kwargs["model"] = cast(SupportedModels, args[0])
             args = args[1:]
 
-        explicit_provider = kwargs.get("provider")
+        explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider"))
         if explicit_provider:
             provider = explicit_provider
             use_native = True
             model_string = model
         elif "/" in model:
-            prefix, _, model_part = model.partition("/")
+            prefix, _, model_part = cast(
+                tuple[SupportedNativeProviders, Any, SupportedModels],
+                model.partition("/"),
+            )
 
-            provider_mapping = {
+            provider_mapping: dict[str, SupportedNativeProviders] = {
                 "openai": "openai",
                 "anthropic": "anthropic",
                 "claude": "anthropic",
@@ -122,7 +125,9 @@
         return super().__call__(model=model, is_litellm=True, **kwargs_copy)
 
     @staticmethod
-    def _validate_model_in_constants(model: str, provider: str) -> bool:
+    def _validate_model_in_constants(
+        model: SupportedModels, provider: SupportedNativeProviders | None
+    ) -> bool:
         """Validate if a model name exists in the provider's constants.
 
         Args:
@@ -132,12 +137,6 @@
         Returns:
             True if the model exists in the provider's constants, False otherwise
         """
-        from crewai.llm.constants import (
-            ANTHROPIC_MODELS,
-            BEDROCK_MODELS,
-            GEMINI_MODELS,
-            OPENAI_MODELS,
-        )
 
         if provider == "openai":
             return model in OPENAI_MODELS
@@ -158,7 +157,9 @@
         return False
 
     @staticmethod
-    def _infer_provider_from_model(model: str) -> str:
+    def _infer_provider_from_model(
+        model: SupportedModels | str,
+    ) -> SupportedNativeProviders:
         """Infer the provider from the model name.
 
         Args:
@@ -167,13 +168,6 @@
         Returns:
             The inferred provider name, defaults to "openai"
         """
-        from crewai.llm.constants import (
-            ANTHROPIC_MODELS,
-            AZURE_MODELS,
-            BEDROCK_MODELS,
-            GEMINI_MODELS,
-            OPENAI_MODELS,
-        )
 
         if model in OPENAI_MODELS:
             return "openai"
@@ -193,7 +187,7 @@
         return "openai"
 
     @staticmethod
-    def _get_native_provider(provider: str) -> type | None:
+    def _get_native_provider(provider: SupportedNativeProviders | None) -> type | None:
         """Get native provider class if available.
 
         Args:
diff --git a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py
index dd13c0f5e..05366bbf4 100644
--- a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py
+++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py
@@ -48,11 +48,6 @@ class AnthropicCompletion(BaseLLM):
         client_params: Additional parameters for the Anthropic client
         interceptor: HTTP interceptor for modifying requests/responses at transport level
     """
-
-    model_config: ClassVar[ConfigDict] = ConfigDict(
-        ignored_types=(property,), arbitrary_types_allowed=True
-    )
-
     base_url: str | None = Field(
         default=None, description="Custom base URL for Anthropic API"
     )
@@ -68,11 +63,8 @@
     client_params: dict[str, Any] | None = Field(
         default=None, description="Additional Anthropic client parameters"
     )
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor for request/response modification"
-    )
-    client: Any = Field(
-        default=None, exclude=True, description="Anthropic client instance"
+    client: Anthropic = Field(
+        default_factory=Anthropic, exclude=True, description="Anthropic client instance"
     )
 
     _is_claude_3: bool = PrivateAttr(default=False)
diff --git a/lib/crewai/src/crewai/llm/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py
index 6076b9f4d..3a4f68d08 100644
--- a/lib/crewai/src/crewai/llm/providers/azure/completion.py
+++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py
@@ -91,9 +91,6 @@ class AzureCompletion(BaseLLM):
         default=None, description="Maximum tokens in response"
     )
     stream: bool = Field(default=False, description="Enable streaming responses")
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor (not yet supported for Azure)"
-    )
     client: Any = Field(default=None, exclude=True, description="Azure client instance")
 
     _is_openai_model: bool = PrivateAttr(default=False)
diff --git a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py
index ed6738c4b..58495a151 100644
--- a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py
+++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py
@@ -151,7 +151,6 @@ class BedrockCompletion(BaseLLM):
         max_tokens: Maximum tokens to generate
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter (Claude models only)
-        stop_sequences: List of sequences that stop generation
         stream: Whether to use streaming responses
         guardrail_config: Guardrail configuration for content filtering
         additional_model_request_fields: Model-specific request parameters
@@ -192,9 +191,6 @@
     additional_model_response_field_paths: list[str] | None = Field(
         default=None, description="Custom response field paths"
     )
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor (not yet supported for Bedrock)"
-    )
     client: Any = Field(
         default=None, exclude=True, description="Bedrock client instance"
     )
diff --git a/lib/crewai/src/crewai/llm/providers/gemini/completion.py b/lib/crewai/src/crewai/llm/providers/gemini/completion.py
index 34bff2508..f2b191656 100644
--- a/lib/crewai/src/crewai/llm/providers/gemini/completion.py
+++ b/lib/crewai/src/crewai/llm/providers/gemini/completion.py
@@ -39,7 +39,6 @@ class GeminiCompletion(BaseLLM):
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter
         max_output_tokens: Maximum tokens in response
-        stop_sequences: Stop sequences
         stream: Enable streaming responses
         safety_settings: Safety filter settings
         client_params: Additional parameters for Google Gen AI Client constructor
@@ -70,9 +69,6 @@
         default_factory=dict,
         description="Additional parameters for Google Gen AI Client constructor",
     )
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor (not yet supported for Gemini)"
-    )
     client: Any = Field(
         default=None, exclude=True, description="Gemini client instance"
     )
@@ -81,28 +77,6 @@
     _is_gemini_1_5: bool = PrivateAttr(default=False)
     _supports_tools: bool = PrivateAttr(default=False)
 
-    @property
-    def stop_sequences(self) -> list[str]:
-        """Get stop sequences as a list.
-
-        This property provides access to stop sequences in Gemini's native format
-        while maintaining synchronization with the base class's stop attribute.
-        """
-        if self.stop is None:
-            return []
-        if isinstance(self.stop, str):
-            return [self.stop]
-        return self.stop
-
-    @stop_sequences.setter
-    def stop_sequences(self, value: list[str] | str | None) -> None:
-        """Set stop sequences, synchronizing with the stop attribute.
-
-        Args:
-            value: Stop sequences as a list, string, or None
-        """
-        self.stop = value
-
     @model_validator(mode="after")
     def setup_client(self) -> Self:
         """Initialize the Gemini client and validate configuration."""
diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py
index c0d2957f1..5a91b2e1e 100644
--- a/lib/crewai/tests/llms/anthropic/test_anthropic.py
+++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py
@@ -197,7 +197,7 @@ def test_anthropic_specific_parameters():
     from crewai.llm.providers.anthropic.completion import AnthropicCompletion
 
     assert isinstance(llm, AnthropicCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.client.max_retries == 5
     assert llm.client.timeout == 60
@@ -667,23 +667,21 @@ def test_anthropic_token_usage_tracking():
 
 
 def test_anthropic_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
 
     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]
 
-    # Test setting stop as a string
+    # Test setting stop as a string - note: setting via attribute doesn't go through validator
+    # so it stays as a string
     llm.stop = "\nFinal Answer:"
-    assert llm.stop_sequences == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    assert llm.stop == "\nFinal Answer:"
 
     # Test setting stop as None
     llm.stop = None
-    assert llm.stop_sequences == []
-    assert llm.stop == []
+    assert llm.stop is None
 
 
 @pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])
diff --git a/lib/crewai/tests/llms/bedrock/test_bedrock.py b/lib/crewai/tests/llms/bedrock/test_bedrock.py
index 7ad7c2080..b3c12cdc2 100644
--- a/lib/crewai/tests/llms/bedrock/test_bedrock.py
+++ b/lib/crewai/tests/llms/bedrock/test_bedrock.py
@@ -147,7 +147,7 @@ def test_bedrock_specific_parameters():
     from crewai.llm.providers.bedrock.completion import BedrockCompletion
 
     assert isinstance(llm, BedrockCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.region_name == "us-east-1"
 
@@ -739,23 +739,19 @@ def test_bedrock_client_error_handling():
 
 
 def test_bedrock_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
 
     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert list(llm.stop_sequences) == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]
 
-    # Test setting stop as a string
-    llm.stop = "\nFinal Answer:"
-    assert list(llm.stop_sequences) == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    llm2 = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stop_sequences="\nFinal Answer:")
+    assert llm2.stop == ["\nFinal Answer:"]
 
     # Test setting stop as None
     llm.stop = None
-    assert list(llm.stop_sequences) == []
-    assert llm.stop == []
+    assert llm.stop is None
 
 
 def test_bedrock_stop_sequences_sent_to_api():
diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py
index ffb070b5e..11af3e83b 100644
--- a/lib/crewai/tests/llms/google/test_google.py
+++ b/lib/crewai/tests/llms/google/test_google.py
@@ -188,7 +188,7 @@ def test_gemini_specific_parameters():
     from crewai.llm.providers.gemini.completion import GeminiCompletion
 
     assert isinstance(llm, GeminiCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.safety_settings == safety_settings
     assert llm.project == "test-project"
@@ -651,23 +651,20 @@ def test_gemini_token_usage_tracking():
 
 
 def test_gemini_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="google/gemini-2.0-flash-001")
 
     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]
 
     # Test setting stop as a string
     llm.stop = "\nFinal Answer:"
-    assert llm.stop_sequences == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    assert llm.stop == "\nFinal Answer:"
 
     # Test setting stop as None
     llm.stop = None
-    assert llm.stop_sequences == []
-    assert llm.stop == []
+    assert llm.stop is None
 
 
 def test_gemini_stop_sequences_sent_to_api():
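
Reviewer note: the updated tests rely on how Pydantic v2 "before" validators behave: values passed through the constructor (including via the stop_sequences alias) are normalized to a list, while plain attribute assignment bypasses the validator and leaves a string as-is. A minimal, self-contained sketch of that pattern follows; StopDemo is a hypothetical stand-in, not crewai's actual BaseLLM.

# sketch only - illustrates the field_validator(mode="before") pattern used above
from typing import Any

from pydantic import BaseModel, Field, field_validator


class StopDemo(BaseModel):
    stop: list[str] = Field(default_factory=list)

    @field_validator("stop", mode="before")
    @classmethod
    def _normalize_stop(cls, value: Any) -> list[str]:
        # None -> [], str -> [str], list passes through, anything else -> []
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        if isinstance(value, list):
            return value
        return []


# Constructor input goes through the validator and is normalized.
assert StopDemo(stop="\nFinal Answer:").stop == ["\nFinal Answer:"]
assert StopDemo(stop=None).stop == []

# Attribute assignment skips "before" validators unless validate_assignment=True,
# which is why the updated tests expect llm.stop = "\nFinal Answer:" to stay a string.
demo = StopDemo()
demo.stop = "\nFinal Answer:"  # type: ignore[assignment]
assert demo.stop == "\nFinal Answer:"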