chore: fix attr ref
@@ -14,7 +14,7 @@ import re
 from typing import TYPE_CHECKING, Any, ClassVar, Final
 
 import httpx
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.llm_events import (
@@ -62,11 +62,10 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         model: The model identifier/name.
         temperature: Optional temperature setting for response generation.
         stop: A list of stop sequences that the LLM should use to stop generation.
-        additional_params: Additional provider-specific parameters.
     """
 
     model_config: ClassVar[ConfigDict] = ConfigDict(
-        arbitrary_types_allowed=True, extra="allow"
+        extra="allow", populate_by_name=True
     )
 
     # Core fields
@@ -82,7 +81,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
     stop: list[str] = Field(
         default_factory=list,
         description="Stop sequences for generation",
-        validation_alias="stop_sequences",
+        alias="stop_sequences",
     )
 
     # Internal fields
@@ -90,7 +89,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         default=False, description="Whether this instance uses LiteLLM"
     )
     interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = Field(
-        None, description="HTTP request/response interceptor"
+        default=None, description="HTTP request/response interceptor"
     )
     _token_usage: dict[str, int] = {
         "total_tokens": 0,
@@ -100,6 +99,25 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         "cached_prompt_tokens": 0,
     }
 
+    @field_validator("stop", mode="before")
+    @classmethod
+    def _normalize_stop(cls, value: Any) -> list[str]:
+        """Normalize stop sequences to a list.
+
+        Args:
+            value: Stop sequences as string, list, or None
+
+        Returns:
+            Normalized list of stop sequences
+        """
+        if value is None:
+            return []
+        if isinstance(value, str):
+            return [value]
+        if isinstance(value, list):
+            return value
+        return []
+
     @model_validator(mode="before")
     @classmethod
     def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]:
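A minimal standalone sketch (not the actual BaseLLM class) of how the combination above behaves in Pydantic v2: alias="stop_sequences" plus populate_by_name=True lets callers pass either name at construction, and the before-mode field validator coerces None or a bare string into a list. The StopDemo name and the asserts are illustrative only; the field name, alias, and validator logic mirror the diff.

from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_validator


class StopDemo(BaseModel):
    # Mirrors the pattern above: "stop" is the field, "stop_sequences" the alias.
    model_config = ConfigDict(extra="allow", populate_by_name=True)

    stop: list[str] = Field(
        default_factory=list,
        description="Stop sequences for generation",
        alias="stop_sequences",
    )

    @field_validator("stop", mode="before")
    @classmethod
    def _normalize_stop(cls, value: Any) -> list[str]:
        # None -> [], "x" -> ["x"], list passes through, anything else -> []
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        if isinstance(value, list):
            return value
        return []


# Both spellings populate the same field because populate_by_name=True.
assert StopDemo(stop_sequences="Human:").stop == ["Human:"]
assert StopDemo(stop=["Human:", "Assistant:"]).stop == ["Human:", "Assistant:"]
assert StopDemo().stop == []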
@@ -1,6 +1,31 @@
 from typing import Literal, TypeAlias
 
 
+SupportedNativeProviders: TypeAlias = Literal[
+    "openai",
+    "anthropic",
+    "claude",
+    "azure",
+    "azure_openai",
+    "google",
+    "gemini",
+    "bedrock",
+    "aws",
+]
+
+SUPPORTED_NATIVE_PROVIDERS: list[SupportedNativeProviders] = [
+    "openai",
+    "anthropic",
+    "claude",
+    "azure",
+    "azure_openai",
+    "google",
+    "gemini",
+    "bedrock",
+    "aws",
+]
+
+
 OpenAIModels: TypeAlias = Literal[
     "gpt-3.5-turbo",
     "gpt-3.5-turbo-0125",
@@ -556,3 +581,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
     "qwen.qwen3-coder-30b-a3b-v1:0",
     "twelvelabs.pegasus-1-2-v1:0",
 ]
+
+SupportedModels: TypeAlias = (
+    OpenAIModels | AnthropicModels | GeminiModels | AzureModels | BedrockModels
+)
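A short sketch of why the constants pair a Literal alias with a plain list: the alias constrains values for the type checker, while the list is what runtime code can actually test membership against. The trimmed provider set and the is_native_provider helper below are illustrative, not part of the diff; typing.get_args is standard-library behavior that recovers the Literal members at runtime.

from typing import Literal, TypeAlias, get_args

SupportedNativeProviders: TypeAlias = Literal["openai", "anthropic", "gemini"]
SUPPORTED_NATIVE_PROVIDERS: list[SupportedNativeProviders] = ["openai", "anthropic", "gemini"]

# get_args() yields the Literal members in order, so a one-line check (e.g. in
# a test) can keep the alias and the runtime list from drifting apart.
assert list(get_args(SupportedNativeProviders)) == SUPPORTED_NATIVE_PROVIDERS


def is_native_provider(name: str) -> bool:
    # Type checkers see the Literal values; at runtime we fall back to the list.
    return name in SUPPORTED_NATIVE_PROVIDERS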
@@ -7,23 +7,20 @@ based on the model parameter at instantiation time.
 from __future__ import annotations
 
 import logging
-from typing import Any
+from typing import Any, cast
 
 from pydantic._internal._model_construction import ModelMetaclass
 
-
-# Provider constants imported from crewai.llm.constants
-SUPPORTED_NATIVE_PROVIDERS: list[str] = [
-    "openai",
-    "anthropic",
-    "claude",
-    "azure",
-    "azure_openai",
-    "google",
-    "gemini",
-    "bedrock",
-    "aws",
-]
+from crewai.llm.constants import (
+    ANTHROPIC_MODELS,
+    AZURE_MODELS,
+    BEDROCK_MODELS,
+    GEMINI_MODELS,
+    OPENAI_MODELS,
+    SUPPORTED_NATIVE_PROVIDERS,
+    SupportedModels,
+    SupportedNativeProviders,
+)
 
 
 class LLMMeta(ModelMetaclass):
@@ -49,25 +46,31 @@ class LLMMeta(ModelMetaclass):
         if cls.__name__ != "LLM":
             return super().__call__(*args, **kwargs)
 
-        model = kwargs.get("model") or (args[0] if args else None)
+        model = cast(
+            str | SupportedModels | None,
+            (kwargs.get("model") or (args[0] if args else None)),
+        )
         is_litellm = kwargs.get("is_litellm", False)
 
         if not model or not isinstance(model, str):
             raise ValueError("Model must be a non-empty string")
 
         if args and not kwargs.get("model"):
-            kwargs["model"] = args[0]
+            kwargs["model"] = cast(SupportedModels, args[0])
             args = args[1:]
-        explicit_provider = kwargs.get("provider")
+        explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider"))
 
         if explicit_provider:
             provider = explicit_provider
            use_native = True
             model_string = model
         elif "/" in model:
-            prefix, _, model_part = model.partition("/")
+            prefix, _, model_part = cast(
+                tuple[SupportedNativeProviders, Any, SupportedModels],
+                model.partition("/"),
+            )
 
-            provider_mapping = {
+            provider_mapping: dict[str, SupportedNativeProviders] = {
                 "openai": "openai",
                 "anthropic": "anthropic",
                 "claude": "anthropic",
@@ -122,7 +125,9 @@ class LLMMeta(ModelMetaclass):
             return super().__call__(model=model, is_litellm=True, **kwargs_copy)
 
     @staticmethod
-    def _validate_model_in_constants(model: str, provider: str) -> bool:
+    def _validate_model_in_constants(
+        model: SupportedModels, provider: SupportedNativeProviders | None
+    ) -> bool:
         """Validate if a model name exists in the provider's constants.
 
         Args:
@@ -132,12 +137,6 @@ class LLMMeta(ModelMetaclass):
         Returns:
             True if the model exists in the provider's constants, False otherwise
         """
-        from crewai.llm.constants import (
-            ANTHROPIC_MODELS,
-            BEDROCK_MODELS,
-            GEMINI_MODELS,
-            OPENAI_MODELS,
-        )
 
         if provider == "openai":
             return model in OPENAI_MODELS
@@ -158,7 +157,9 @@ class LLMMeta(ModelMetaclass):
         return False
 
     @staticmethod
-    def _infer_provider_from_model(model: str) -> str:
+    def _infer_provider_from_model(
+        model: SupportedModels | str,
+    ) -> SupportedNativeProviders:
         """Infer the provider from the model name.
 
         Args:
@@ -167,13 +168,6 @@ class LLMMeta(ModelMetaclass):
         Returns:
             The inferred provider name, defaults to "openai"
         """
-        from crewai.llm.constants import (
-            ANTHROPIC_MODELS,
-            AZURE_MODELS,
-            BEDROCK_MODELS,
-            GEMINI_MODELS,
-            OPENAI_MODELS,
-        )
 
         if model in OPENAI_MODELS:
             return "openai"
@@ -193,7 +187,7 @@ class LLMMeta(ModelMetaclass):
         return "openai"
 
     @staticmethod
-    def _get_native_provider(provider: str) -> type | None:
+    def _get_native_provider(provider: SupportedNativeProviders | None) -> type | None:
         """Get native provider class if available.
 
         Args:
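The metaclass changes above lean on typing.cast. A small reminder-style sketch, separate from the crewAI code: cast performs no runtime conversion or validation, it only tells the type checker to treat the value as the given type, which is why the methods above still guard with isinstance() and membership checks. The Provider alias and pick helper are illustrative names.

from typing import Literal, cast

Provider = Literal["openai", "anthropic"]


def pick(raw: str | None) -> Provider:
    # cast() does nothing at runtime; it only narrows the type for the checker.
    return cast(Provider, raw or "openai")


assert pick(None) == "openai"
assert pick("anthropic") == "anthropic"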
@@ -48,11 +48,6 @@ class AnthropicCompletion(BaseLLM):
         client_params: Additional parameters for the Anthropic client
         interceptor: HTTP interceptor for modifying requests/responses at transport level
     """
-
-    model_config: ClassVar[ConfigDict] = ConfigDict(
-        ignored_types=(property,), arbitrary_types_allowed=True
-    )
-
     base_url: str | None = Field(
         default=None, description="Custom base URL for Anthropic API"
     )
@@ -68,11 +63,8 @@ class AnthropicCompletion(BaseLLM):
     client_params: dict[str, Any] | None = Field(
         default=None, description="Additional Anthropic client parameters"
     )
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor for request/response modification"
-    )
-    client: Any = Field(
-        default=None, exclude=True, description="Anthropic client instance"
+    client: Anthropic = Field(
+        default_factory=Anthropic, exclude=True, description="Anthropic client instance"
    )
 
     _is_claude_3: bool = PrivateAttr(default=False)
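The hunk above swaps a None-defaulted client: Any field for a typed field built with default_factory. A generic sketch of that pattern, using a stand-in FakeClient class rather than the real Anthropic SDK and a hypothetical CompletionDemo model: default_factory builds a fresh client per instance, and exclude=True keeps the client out of serialization.

from pydantic import BaseModel, ConfigDict, Field


class FakeClient:
    """Stand-in for an SDK client such as Anthropic (illustrative only)."""

    def __init__(self) -> None:
        self.max_retries = 2


class CompletionDemo(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # default_factory constructs a fresh client for every model instance
    # (a plain default object would be shared); exclude=True keeps it out
    # of model_dump() and other serialization.
    client: FakeClient = Field(default_factory=FakeClient, exclude=True)


a, b = CompletionDemo(), CompletionDemo()
assert a.client is not b.client
assert "client" in CompletionDemo.model_fields
assert "client" not in a.model_dump()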
@@ -91,9 +91,6 @@ class AzureCompletion(BaseLLM):
         default=None, description="Maximum tokens in response"
     )
     stream: bool = Field(default=False, description="Enable streaming responses")
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor (not yet supported for Azure)"
-    )
     client: Any = Field(default=None, exclude=True, description="Azure client instance")
 
     _is_openai_model: bool = PrivateAttr(default=False)
@@ -151,7 +151,6 @@ class BedrockCompletion(BaseLLM):
         max_tokens: Maximum tokens to generate
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter (Claude models only)
-        stop_sequences: List of sequences that stop generation
         stream: Whether to use streaming responses
         guardrail_config: Guardrail configuration for content filtering
         additional_model_request_fields: Model-specific request parameters
@@ -192,9 +191,6 @@ class BedrockCompletion(BaseLLM):
     additional_model_response_field_paths: list[str] | None = Field(
         default=None, description="Custom response field paths"
     )
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor (not yet supported for Bedrock)"
-    )
     client: Any = Field(
         default=None, exclude=True, description="Bedrock client instance"
     )
@@ -39,7 +39,6 @@ class GeminiCompletion(BaseLLM):
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter
         max_output_tokens: Maximum tokens in response
-        stop_sequences: Stop sequences
         stream: Enable streaming responses
         safety_settings: Safety filter settings
         client_params: Additional parameters for Google Gen AI Client constructor
@@ -70,9 +69,6 @@ class GeminiCompletion(BaseLLM):
         default_factory=dict,
         description="Additional parameters for Google Gen AI Client constructor",
     )
-    interceptor: Any = Field(
-        default=None, description="HTTP interceptor (not yet supported for Gemini)"
-    )
     client: Any = Field(
         default=None, exclude=True, description="Gemini client instance"
     )
@@ -81,28 +77,6 @@ class GeminiCompletion(BaseLLM):
     _is_gemini_1_5: bool = PrivateAttr(default=False)
     _supports_tools: bool = PrivateAttr(default=False)
 
-    @property
-    def stop_sequences(self) -> list[str]:
-        """Get stop sequences as a list.
-
-        This property provides access to stop sequences in Gemini's native format
-        while maintaining synchronization with the base class's stop attribute.
-        """
-        if self.stop is None:
-            return []
-        if isinstance(self.stop, str):
-            return [self.stop]
-        return self.stop
-
-    @stop_sequences.setter
-    def stop_sequences(self, value: list[str] | str | None) -> None:
-        """Set stop sequences, synchronizing with the stop attribute.
-
-        Args:
-            value: Stop sequences as a list, string, or None
-        """
-        self.stop = value
-
     @model_validator(mode="after")
     def setup_client(self) -> Self:
         """Initialize the Gemini client and validate configuration."""
@@ -197,7 +197,7 @@ def test_anthropic_specific_parameters():
 
     from crewai.llm.providers.anthropic.completion import AnthropicCompletion
     assert isinstance(llm, AnthropicCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.client.max_retries == 5
     assert llm.client.timeout == 60
@@ -667,23 +667,21 @@ def test_anthropic_token_usage_tracking():
 
 
 def test_anthropic_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
 
     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]
 
-    # Test setting stop as a string
+    # Test setting stop as a string - note: setting via attribute doesn't go through validator
+    # so it stays as a string
     llm.stop = "\nFinal Answer:"
-    assert llm.stop_sequences == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    assert llm.stop == "\nFinal Answer:"
 
     # Test setting stop as None
     llm.stop = None
-    assert llm.stop_sequences == []
-    assert llm.stop == []
+    assert llm.stop is None
 
 
 @pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])
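The updated test above notes that assigning to llm.stop after construction does not pass through the before-validator. A minimal sketch of the underlying Pydantic v2 behavior, using hypothetical Loose and Strict models: field validators rerun on attribute assignment only when validate_assignment=True is set in the model config.

from typing import Any

from pydantic import BaseModel, ConfigDict, Field, field_validator


class Loose(BaseModel):
    stop: list[str] = Field(default_factory=list)

    @field_validator("stop", mode="before")
    @classmethod
    def _normalize(cls, value: Any) -> list[str]:
        return [value] if isinstance(value, str) else (value or [])


class Strict(Loose):
    # Only with validate_assignment=True does "llm.stop = ..." re-run validators.
    model_config = ConfigDict(validate_assignment=True)


loose, strict = Loose(), Strict()
loose.stop = "\nFinal Answer:"   # bypasses the validator, stays a plain string
strict.stop = "\nFinal Answer:"  # re-validated, normalized to a list
assert loose.stop == "\nFinal Answer:"
assert strict.stop == ["\nFinal Answer:"]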
@@ -147,7 +147,7 @@ def test_bedrock_specific_parameters():
 
     from crewai.llm.providers.bedrock.completion import BedrockCompletion
     assert isinstance(llm, BedrockCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.region_name == "us-east-1"
 
@@ -739,23 +739,19 @@ def test_bedrock_client_error_handling():
 
 
 def test_bedrock_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
 
     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert list(llm.stop_sequences) == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]
 
-    # Test setting stop as a string
-    llm.stop = "\nFinal Answer:"
-    assert list(llm.stop_sequences) == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    llm2 = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stop_sequences="\nFinal Answer:")
+    assert llm2.stop == ["\nFinal Answer:"]
 
     # Test setting stop as None
     llm.stop = None
-    assert list(llm.stop_sequences) == []
-    assert llm.stop == []
+    assert llm.stop is None
 
 
 def test_bedrock_stop_sequences_sent_to_api():
@@ -188,7 +188,7 @@ def test_gemini_specific_parameters():
 
     from crewai.llm.providers.gemini.completion import GeminiCompletion
     assert isinstance(llm, GeminiCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.safety_settings == safety_settings
     assert llm.project == "test-project"
@@ -651,23 +651,20 @@ def test_gemini_token_usage_tracking():
 
 
 def test_gemini_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="google/gemini-2.0-flash-001")
 
     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]
 
     # Test setting stop as a string
     llm.stop = "\nFinal Answer:"
-    assert llm.stop_sequences == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    assert llm.stop == "\nFinal Answer:"
 
     # Test setting stop as None
     llm.stop = None
-    assert llm.stop_sequences == []
-    assert llm.stop == []
+    assert llm.stop is None
 
 
 def test_gemini_stop_sequences_sent_to_api():