mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 15:18:29 +00:00
chore: fix attr ref
@@ -14,7 +14,7 @@ import re
 from typing import TYPE_CHECKING, Any, ClassVar, Final

 import httpx
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.llm_events import (
@@ -62,11 +62,10 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         model: The model identifier/name.
         temperature: Optional temperature setting for response generation.
         stop: A list of stop sequences that the LLM should use to stop generation.
         additional_params: Additional provider-specific parameters.
     """

     model_config: ClassVar[ConfigDict] = ConfigDict(
-        arbitrary_types_allowed=True, extra="allow"
+        extra="allow", populate_by_name=True
     )

     # Core fields
@@ -82,7 +81,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
     stop: list[str] = Field(
         default_factory=list,
         description="Stop sequences for generation",
-        validation_alias="stop_sequences",
+        alias="stop_sequences",
     )

     # Internal fields
@@ -90,7 +89,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         default=False, description="Whether this instance uses LiteLLM"
     )
     interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = Field(
-        None, description="HTTP request/response interceptor"
+        default=None, description="HTTP request/response interceptor"
     )
     _token_usage: dict[str, int] = {
         "total_tokens": 0,
@@ -100,6 +99,25 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         "cached_prompt_tokens": 0,
     }

+    @field_validator("stop", mode="before")
+    @classmethod
+    def _normalize_stop(cls, value: Any) -> list[str]:
+        """Normalize stop sequences to a list.
+
+        Args:
+            value: Stop sequences as string, list, or None
+
+        Returns:
+            Normalized list of stop sequences
+        """
+        if value is None:
+            return []
+        if isinstance(value, str):
+            return [value]
+        if isinstance(value, list):
+            return value
+        return []
+
     @model_validator(mode="before")
     @classmethod
     def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]:
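
For context on the validator added above, here is a minimal standalone sketch of how a pydantic mode="before" field validator normalizes constructor input; the StopDemo model is illustrative and not part of the crewAI codebase.

from typing import Any

from pydantic import BaseModel, Field, field_validator


class StopDemo(BaseModel):
    # Toy model mirroring the normalization BaseLLM applies to `stop`.
    stop: list[str] = Field(default_factory=list)

    @field_validator("stop", mode="before")
    @classmethod
    def _normalize_stop(cls, value: Any) -> list[str]:
        # Runs before standard validation, so str and None inputs are coerced.
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        if isinstance(value, list):
            return value
        return []


assert StopDemo(stop="\nFinal Answer:").stop == ["\nFinal Answer:"]
assert StopDemo(stop=None).stop == []
assert StopDemo().stop == []
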
@@ -1,6 +1,31 @@
 from typing import Literal, TypeAlias


+SupportedNativeProviders: TypeAlias = Literal[
+    "openai",
+    "anthropic",
+    "claude",
+    "azure",
+    "azure_openai",
+    "google",
+    "gemini",
+    "bedrock",
+    "aws",
+]
+
+SUPPORTED_NATIVE_PROVIDERS: list[SupportedNativeProviders] = [
+    "openai",
+    "anthropic",
+    "claude",
+    "azure",
+    "azure_openai",
+    "google",
+    "gemini",
+    "bedrock",
+    "aws",
+]
+
+
 OpenAIModels: TypeAlias = Literal[
     "gpt-3.5-turbo",
     "gpt-3.5-turbo-0125",
@@ -556,3 +581,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
     "qwen.qwen3-coder-30b-a3b-v1:0",
     "twelvelabs.pegasus-1-2-v1:0",
 ]
+
+SupportedModels: TypeAlias = (
+    OpenAIModels | AnthropicModels | GeminiModels | AzureModels | BedrockModels
+)
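
As an aside, a Literal alias like the one added here can also drive the runtime list, which keeps the two definitions from drifting apart. A small sketch under that assumption; the get_args approach is illustrative, the diff defines the list by hand.

from typing import Literal, TypeAlias, get_args

SupportedNativeProviders: TypeAlias = Literal[
    "openai", "anthropic", "claude", "azure", "azure_openai",
    "google", "gemini", "bedrock", "aws",
]

# Derive the runtime list from the Literal so the two stay in sync.
SUPPORTED_NATIVE_PROVIDERS: list[str] = list(get_args(SupportedNativeProviders))

assert "gemini" in SUPPORTED_NATIVE_PROVIDERS
assert "cohere" not in SUPPORTED_NATIVE_PROVIDERS
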
@@ -7,23 +7,20 @@ based on the model parameter at instantiation time.
 from __future__ import annotations

 import logging
-from typing import Any
+from typing import Any, cast

 from pydantic._internal._model_construction import ModelMetaclass

-
-# Provider constants imported from crewai.llm.constants
-SUPPORTED_NATIVE_PROVIDERS: list[str] = [
-    "openai",
-    "anthropic",
-    "claude",
-    "azure",
-    "azure_openai",
-    "google",
-    "gemini",
-    "bedrock",
-    "aws",
-]
+from crewai.llm.constants import (
+    ANTHROPIC_MODELS,
+    AZURE_MODELS,
+    BEDROCK_MODELS,
+    GEMINI_MODELS,
+    OPENAI_MODELS,
+    SUPPORTED_NATIVE_PROVIDERS,
+    SupportedModels,
+    SupportedNativeProviders,
+)


 class LLMMeta(ModelMetaclass):
@@ -49,25 +46,31 @@ class LLMMeta(ModelMetaclass):
         if cls.__name__ != "LLM":
             return super().__call__(*args, **kwargs)

-        model = kwargs.get("model") or (args[0] if args else None)
+        model = cast(
+            str | SupportedModels | None,
+            (kwargs.get("model") or (args[0] if args else None)),
+        )
         is_litellm = kwargs.get("is_litellm", False)

         if not model or not isinstance(model, str):
             raise ValueError("Model must be a non-empty string")

         if args and not kwargs.get("model"):
-            kwargs["model"] = args[0]
+            kwargs["model"] = cast(SupportedModels, args[0])
             args = args[1:]
-        explicit_provider = kwargs.get("provider")
+        explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider"))

         if explicit_provider:
             provider = explicit_provider
             use_native = True
             model_string = model
         elif "/" in model:
-            prefix, _, model_part = model.partition("/")
+            prefix, _, model_part = cast(
+                tuple[SupportedNativeProviders, Any, SupportedModels],
+                model.partition("/"),
+            )

-            provider_mapping = {
+            provider_mapping: dict[str, SupportedNativeProviders] = {
                 "openai": "openai",
                 "anthropic": "anthropic",
                 "claude": "anthropic",
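
For orientation, an illustrative standalone sketch of the prefix-based routing performed above; the helper name is invented for the example and the mapping is truncated to the entries visible in the hunk.

def infer_provider(model: str) -> tuple[str, str]:
    # Map known prefixes of "provider/model" strings to canonical provider names.
    provider_mapping = {
        "openai": "openai",
        "anthropic": "anthropic",
        "claude": "anthropic",
    }
    if "/" in model:
        prefix, _, model_part = model.partition("/")
        return provider_mapping.get(prefix, "openai"), model_part
    # Without a prefix the real metaclass infers the provider from the model
    # constants; the example simply defaults to "openai".
    return "openai", model


assert infer_provider("anthropic/claude-3-5-sonnet-20241022") == (
    "anthropic",
    "claude-3-5-sonnet-20241022",
)
assert infer_provider("gpt-4o") == ("openai", "gpt-4o")
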
@@ -122,7 +125,9 @@ class LLMMeta(ModelMetaclass):
         return super().__call__(model=model, is_litellm=True, **kwargs_copy)

     @staticmethod
-    def _validate_model_in_constants(model: str, provider: str) -> bool:
+    def _validate_model_in_constants(
+        model: SupportedModels, provider: SupportedNativeProviders | None
+    ) -> bool:
         """Validate if a model name exists in the provider's constants.

         Args:
@@ -132,12 +137,6 @@ class LLMMeta(ModelMetaclass):
         Returns:
             True if the model exists in the provider's constants, False otherwise
         """
-        from crewai.llm.constants import (
-            ANTHROPIC_MODELS,
-            BEDROCK_MODELS,
-            GEMINI_MODELS,
-            OPENAI_MODELS,
-        )

         if provider == "openai":
             return model in OPENAI_MODELS
@@ -158,7 +157,9 @@ class LLMMeta(ModelMetaclass):
         return False

     @staticmethod
-    def _infer_provider_from_model(model: str) -> str:
+    def _infer_provider_from_model(
+        model: SupportedModels | str,
+    ) -> SupportedNativeProviders:
         """Infer the provider from the model name.

         Args:
@@ -167,13 +168,6 @@ class LLMMeta(ModelMetaclass):
         Returns:
             The inferred provider name, defaults to "openai"
         """
-        from crewai.llm.constants import (
-            ANTHROPIC_MODELS,
-            AZURE_MODELS,
-            BEDROCK_MODELS,
-            GEMINI_MODELS,
-            OPENAI_MODELS,
-        )

         if model in OPENAI_MODELS:
             return "openai"
@@ -193,7 +187,7 @@ class LLMMeta(ModelMetaclass):
         return "openai"

     @staticmethod
-    def _get_native_provider(provider: str) -> type | None:
+    def _get_native_provider(provider: SupportedNativeProviders | None) -> type | None:
         """Get native provider class if available.

         Args:
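
And a rough sketch of the kind of lookup a helper like _get_native_provider performs, returning a class to instantiate or None so the caller falls back to the LiteLLM path (as the is_litellm=True call above does); the class and registry names here are invented for the example.

class _FakeOpenAICompletion: ...
class _FakeAnthropicCompletion: ...


_NATIVE_PROVIDER_CLASSES: dict[str, type] = {
    "openai": _FakeOpenAICompletion,
    "anthropic": _FakeAnthropicCompletion,
}


def get_native_provider(provider: str | None) -> type | None:
    # None means "no native implementation"; the caller then uses LiteLLM.
    if provider is None:
        return None
    return _NATIVE_PROVIDER_CLASSES.get(provider)


assert get_native_provider("anthropic") is _FakeAnthropicCompletion
assert get_native_provider("cohere") is None
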
@@ -48,11 +48,6 @@ class AnthropicCompletion(BaseLLM):
         client_params: Additional parameters for the Anthropic client
         interceptor: HTTP interceptor for modifying requests/responses at transport level
     """

-    model_config: ClassVar[ConfigDict] = ConfigDict(
-        ignored_types=(property,), arbitrary_types_allowed=True
-    )
-
     base_url: str | None = Field(
         default=None, description="Custom base URL for Anthropic API"
     )
@@ -68,11 +63,8 @@ class AnthropicCompletion(BaseLLM):
     client_params: dict[str, Any] | None = Field(
         default=None, description="Additional Anthropic client parameters"
     )
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor for request/response modification"
-    )
-    client: Any = Field(
-        default=None, exclude=True, description="Anthropic client instance"
+    client: Anthropic = Field(
+        default_factory=Anthropic, exclude=True, description="Anthropic client instance"
     )

     _is_claude_3: bool = PrivateAttr(default=False)
@@ -91,9 +91,6 @@ class AzureCompletion(BaseLLM):
         default=None, description="Maximum tokens in response"
     )
     stream: bool = Field(default=False, description="Enable streaming responses")
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor (not yet supported for Azure)"
-    )
     client: Any = Field(default=None, exclude=True, description="Azure client instance")

     _is_openai_model: bool = PrivateAttr(default=False)
@@ -151,7 +151,6 @@ class BedrockCompletion(BaseLLM):
         max_tokens: Maximum tokens to generate
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter (Claude models only)
-        stop_sequences: List of sequences that stop generation
         stream: Whether to use streaming responses
         guardrail_config: Guardrail configuration for content filtering
         additional_model_request_fields: Model-specific request parameters
@@ -192,9 +191,6 @@ class BedrockCompletion(BaseLLM):
     additional_model_response_field_paths: list[str] | None = Field(
         default=None, description="Custom response field paths"
     )
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor (not yet supported for Bedrock)"
-    )
     client: Any = Field(
         default=None, exclude=True, description="Bedrock client instance"
     )
@@ -39,7 +39,6 @@ class GeminiCompletion(BaseLLM):
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter
         max_output_tokens: Maximum tokens in response
-        stop_sequences: Stop sequences
         stream: Enable streaming responses
         safety_settings: Safety filter settings
         client_params: Additional parameters for Google Gen AI Client constructor
@@ -70,9 +69,6 @@ class GeminiCompletion(BaseLLM):
         default_factory=dict,
         description="Additional parameters for Google Gen AI Client constructor",
     )
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor (not yet supported for Gemini)"
-    )
     client: Any = Field(
         default=None, exclude=True, description="Gemini client instance"
     )
@@ -81,28 +77,6 @@ class GeminiCompletion(BaseLLM):
     _is_gemini_1_5: bool = PrivateAttr(default=False)
     _supports_tools: bool = PrivateAttr(default=False)

-    @property
-    def stop_sequences(self) -> list[str]:
-        """Get stop sequences as a list.
-
-        This property provides access to stop sequences in Gemini's native format
-        while maintaining synchronization with the base class's stop attribute.
-        """
-        if self.stop is None:
-            return []
-        if isinstance(self.stop, str):
-            return [self.stop]
-        return self.stop
-
-    @stop_sequences.setter
-    def stop_sequences(self, value: list[str] | str | None) -> None:
-        """Set stop sequences, synchronizing with the stop attribute.
-
-        Args:
-            value: Stop sequences as a list, string, or None
-        """
-        self.stop = value
-
     @model_validator(mode="after")
     def setup_client(self) -> Self:
         """Initialize the Gemini client and validate configuration."""
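
After this removal the tests below read llm.stop directly; the stop_sequences keyword still works at construction because of the alias added on the base field together with populate_by_name=True. A minimal sketch of that pydantic behavior (standalone demo model, not the crewAI class):

from pydantic import BaseModel, ConfigDict, Field


class AliasDemo(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    stop: list[str] = Field(default_factory=list, alias="stop_sequences")


# Both the alias and the field name are accepted at construction time.
assert AliasDemo(stop_sequences=["Human:"]).stop == ["Human:"]
assert AliasDemo(stop=["Human:"]).stop == ["Human:"]
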
@@ -197,7 +197,7 @@ def test_anthropic_specific_parameters():

     from crewai.llm.providers.anthropic.completion import AnthropicCompletion
     assert isinstance(llm, AnthropicCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.client.max_retries == 5
     assert llm.client.timeout == 60
@@ -667,23 +667,21 @@ def test_anthropic_token_usage_tracking():


 def test_anthropic_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]

-    # Test setting stop as a string
+    # Test setting stop as a string - note: setting via attribute doesn't go through validator
+    # so it stays as a string
     llm.stop = "\nFinal Answer:"
-    assert llm.stop_sequences == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    assert llm.stop == "\nFinal Answer:"

     # Test setting stop as None
     llm.stop = None
-    assert llm.stop_sequences == []
-    assert llm.stop == []
+    assert llm.stop is None


 @pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])
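
The reworded comment in this test reflects standard pydantic behavior: a mode="before" validator runs at construction but not on plain attribute assignment unless validate_assignment=True is enabled. A standalone sketch of that difference (demo classes only; crewAI's actual configuration may differ):

from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_validator


class Validated(BaseModel):
    model_config = ConfigDict(validate_assignment=True)

    stop: list[str] = Field(default_factory=list)

    @field_validator("stop", mode="before")
    @classmethod
    def _normalize(cls, value: Any) -> list[str]:
        if value is None:
            return []
        return [value] if isinstance(value, str) else value


class Unvalidated(Validated):
    # Same validator, but assignments bypass validation entirely.
    model_config = ConfigDict(validate_assignment=False)


a, b = Validated(), Unvalidated()
a.stop = "\nFinal Answer:"
b.stop = "\nFinal Answer:"
assert a.stop == ["\nFinal Answer:"]  # normalized on assignment
assert b.stop == "\nFinal Answer:"    # stays a string, like the test above asserts
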
@@ -147,7 +147,7 @@ def test_bedrock_specific_parameters():

     from crewai.llm.providers.bedrock.completion import BedrockCompletion
     assert isinstance(llm, BedrockCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.region_name == "us-east-1"
@@ -739,23 +739,19 @@ def test_bedrock_client_error_handling():


 def test_bedrock_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert list(llm.stop_sequences) == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]

-    # Test setting stop as a string
-    llm.stop = "\nFinal Answer:"
-    assert list(llm.stop_sequences) == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    llm2 = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stop_sequences="\nFinal Answer:")
+    assert llm2.stop == ["\nFinal Answer:"]

     # Test setting stop as None
     llm.stop = None
-    assert list(llm.stop_sequences) == []
-    assert llm.stop == []
+    assert llm.stop is None


 def test_bedrock_stop_sequences_sent_to_api():
@@ -188,7 +188,7 @@ def test_gemini_specific_parameters():

     from crewai.llm.providers.gemini.completion import GeminiCompletion
     assert isinstance(llm, GeminiCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.safety_settings == safety_settings
     assert llm.project == "test-project"
@@ -651,23 +651,20 @@ def test_gemini_token_usage_tracking():


 def test_gemini_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="google/gemini-2.0-flash-001")

     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]

     # Test setting stop as a string
     llm.stop = "\nFinal Answer:"
-    assert llm.stop_sequences == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    assert llm.stop == "\nFinal Answer:"

     # Test setting stop as None
     llm.stop = None
-    assert llm.stop_sequences == []
-    assert llm.stop == []
+    assert llm.stop is None


 def test_gemini_stop_sequences_sent_to_api():