mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
chore: move api key validation to base
This commit is contained in:
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Final
|
||||
|
||||
@@ -99,6 +100,24 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
|
||||
"cached_prompt_tokens": 0,
|
||||
}
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def _validate_api_key(cls, value: str | None) -> str | None:
|
||||
"""Validate API key for authentication.
|
||||
|
||||
Args:
|
||||
value: API key value or None
|
||||
|
||||
Returns:
|
||||
API key from environment if not provided, or the original value
|
||||
"""
|
||||
if value is None:
|
||||
cls_name = cls.__name__
|
||||
provider_prefix = cls_name.replace("Completion", "").upper()
|
||||
env_var = f"{provider_prefix}_API_KEY"
|
||||
value = os.getenv(env_var)
|
||||
return value
|
||||
|
||||
@field_validator("stop", mode="before")
|
||||
@classmethod
|
||||
def _normalize_stop(cls, value: Any) -> list[str]:
|
||||
|
||||
@@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -29,7 +29,6 @@ try:
|
||||
from anthropic import Anthropic
|
||||
from anthropic.types import Message
|
||||
from anthropic.types.tool_use_block import ToolUseBlock
|
||||
import httpx
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Anthropic native provider not available, to install: uv add "crewai[anthropic]"'
|
||||
@@ -100,9 +99,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
"""Get client parameters."""
|
||||
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if self.api_key is None:
|
||||
raise ValueError("ANTHROPIC_API_KEY is required")
|
||||
raise ValueError("ANTHROPIC_API_KEY is required")
|
||||
|
||||
client_params = {
|
||||
"api_key": self.api_key,
|
||||
|
||||
@@ -107,9 +107,6 @@ class AzureCompletion(BaseLLM):
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("AZURE_API_KEY")
|
||||
|
||||
if self.endpoint is None:
|
||||
self.endpoint = (
|
||||
os.getenv("AZURE_ENDPOINT")
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping, Sequence
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Required, Self
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
@@ -161,10 +161,6 @@ class BedrockCompletion(BaseLLM):
|
||||
interceptor: HTTP interceptor (not yet supported for Bedrock)
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||
ignored_types=(property,), arbitrary_types_allowed=True
|
||||
)
|
||||
|
||||
aws_access_key_id: str | None = Field(
|
||||
default=None, description="AWS access key (defaults to environment variable)"
|
||||
)
|
||||
|
||||
@@ -81,8 +81,6 @@ class OpenAICompletion(BaseLLM):
|
||||
@model_validator(mode="after")
|
||||
def setup_client(self) -> Self:
|
||||
"""Initialize OpenAI client after model validation."""
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
@@ -101,9 +99,7 @@ class OpenAICompletion(BaseLLM):
|
||||
"""Get OpenAI client parameters."""
|
||||
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
if self.api_key is None:
|
||||
raise ValueError("OPENAI_API_KEY is required")
|
||||
raise ValueError("OPENAI_API_KEY is required")
|
||||
|
||||
base_params = {
|
||||
"api_key": self.api_key,
|
||||
|
||||
Reference in New Issue
Block a user