chore: move api key validation to base

This commit is contained in:
Greyson LaLonde
2025-11-11 17:46:26 -05:00
parent 0803318002
commit 93f1fbd75e
5 changed files with 24 additions and 19 deletions

View File

@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
import json import json
import logging import logging
import os
import re import re
from typing import TYPE_CHECKING, Any, ClassVar, Final from typing import TYPE_CHECKING, Any, ClassVar, Final
@@ -99,6 +100,24 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
"cached_prompt_tokens": 0, "cached_prompt_tokens": 0,
} }
@field_validator("api_key", mode="before")
@classmethod
def _validate_api_key(cls, value: str | None) -> str | None:
"""Validate API key for authentication.
Args:
value: API key value or None
Returns:
API key from environment if not provided, or the original value
"""
if value is None:
cls_name = cls.__name__
provider_prefix = cls_name.replace("Completion", "").upper()
env_var = f"{provider_prefix}_API_KEY"
value = os.getenv(env_var)
return value
@field_validator("stop", mode="before") @field_validator("stop", mode="before")
@classmethod @classmethod
def _normalize_stop(cls, value: Any) -> list[str]: def _normalize_stop(cls, value: Any) -> list[str]:

View File

@@ -2,9 +2,9 @@ from __future__ import annotations
import json import json
import logging import logging
import os
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
import httpx
from pydantic import BaseModel, Field, PrivateAttr, model_validator from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self from typing_extensions import Self
@@ -29,7 +29,6 @@ try:
from anthropic import Anthropic from anthropic import Anthropic
from anthropic.types import Message from anthropic.types import Message
from anthropic.types.tool_use_block import ToolUseBlock from anthropic.types.tool_use_block import ToolUseBlock
import httpx
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Anthropic native provider not available, to install: uv add "crewai[anthropic]"' 'Anthropic native provider not available, to install: uv add "crewai[anthropic]"'
@@ -100,9 +99,7 @@ class AnthropicCompletion(BaseLLM):
"""Get client parameters.""" """Get client parameters."""
if self.api_key is None: if self.api_key is None:
self.api_key = os.getenv("ANTHROPIC_API_KEY") raise ValueError("ANTHROPIC_API_KEY is required")
if self.api_key is None:
raise ValueError("ANTHROPIC_API_KEY is required")
client_params = { client_params = {
"api_key": self.api_key, "api_key": self.api_key,

View File

@@ -107,9 +107,6 @@ class AzureCompletion(BaseLLM):
"Interceptors are currently supported for OpenAI and Anthropic providers only." "Interceptors are currently supported for OpenAI and Anthropic providers only."
) )
if self.api_key is None:
self.api_key = os.getenv("AZURE_API_KEY")
if self.endpoint is None: if self.endpoint is None:
self.endpoint = ( self.endpoint = (
os.getenv("AZURE_ENDPOINT") os.getenv("AZURE_ENDPOINT")

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
import logging import logging
import os import os
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast from typing import TYPE_CHECKING, Any, TypedDict, cast
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Required, Self from typing_extensions import Required, Self
from crewai.events.types.llm_events import LLMCallType from crewai.events.types.llm_events import LLMCallType
@@ -161,10 +161,6 @@ class BedrockCompletion(BaseLLM):
interceptor: HTTP interceptor (not yet supported for Bedrock) interceptor: HTTP interceptor (not yet supported for Bedrock)
""" """
model_config: ClassVar[ConfigDict] = ConfigDict(
ignored_types=(property,), arbitrary_types_allowed=True
)
aws_access_key_id: str | None = Field( aws_access_key_id: str | None = Field(
default=None, description="AWS access key (defaults to environment variable)" default=None, description="AWS access key (defaults to environment variable)"
) )

View File

@@ -81,8 +81,6 @@ class OpenAICompletion(BaseLLM):
@model_validator(mode="after") @model_validator(mode="after")
def setup_client(self) -> Self: def setup_client(self) -> Self:
"""Initialize OpenAI client after model validation.""" """Initialize OpenAI client after model validation."""
if self.api_key is None:
self.api_key = os.getenv("OPENAI_API_KEY")
client_config = self._get_client_params() client_config = self._get_client_params()
if self.interceptor: if self.interceptor:
@@ -101,9 +99,7 @@ class OpenAICompletion(BaseLLM):
"""Get OpenAI client parameters.""" """Get OpenAI client parameters."""
if self.api_key is None: if self.api_key is None:
self.api_key = os.getenv("OPENAI_API_KEY") raise ValueError("OPENAI_API_KEY is required")
if self.api_key is None:
raise ValueError("OPENAI_API_KEY is required")
base_params = { base_params = {
"api_key": self.api_key, "api_key": self.api_key,