mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-30 02:28:13 +00:00
chore: move api key validation to base
This commit is contained in:
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar, Final
|
from typing import TYPE_CHECKING, Any, ClassVar, Final
|
||||||
|
|
||||||
@@ -99,6 +100,24 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
|
|||||||
"cached_prompt_tokens": 0,
|
"cached_prompt_tokens": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@field_validator("api_key", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _validate_api_key(cls, value: str | None) -> str | None:
|
||||||
|
"""Validate API key for authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: API key value or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API key from environment if not provided, or the original value
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
cls_name = cls.__name__
|
||||||
|
provider_prefix = cls_name.replace("Completion", "").upper()
|
||||||
|
env_var = f"{provider_prefix}_API_KEY"
|
||||||
|
value = os.getenv(env_var)
|
||||||
|
return value
|
||||||
|
|
||||||
@field_validator("stop", mode="before")
|
@field_validator("stop", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _normalize_stop(cls, value: Any) -> list[str]:
|
def _normalize_stop(cls, value: Any) -> list[str]:
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
import httpx
|
||||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
@@ -29,7 +29,6 @@ try:
|
|||||||
from anthropic import Anthropic
|
from anthropic import Anthropic
|
||||||
from anthropic.types import Message
|
from anthropic.types import Message
|
||||||
from anthropic.types.tool_use_block import ToolUseBlock
|
from anthropic.types.tool_use_block import ToolUseBlock
|
||||||
import httpx
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
'Anthropic native provider not available, to install: uv add "crewai[anthropic]"'
|
'Anthropic native provider not available, to install: uv add "crewai[anthropic]"'
|
||||||
@@ -100,9 +99,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
"""Get client parameters."""
|
"""Get client parameters."""
|
||||||
|
|
||||||
if self.api_key is None:
|
if self.api_key is None:
|
||||||
self.api_key = os.getenv("ANTHROPIC_API_KEY")
|
raise ValueError("ANTHROPIC_API_KEY is required")
|
||||||
if self.api_key is None:
|
|
||||||
raise ValueError("ANTHROPIC_API_KEY is required")
|
|
||||||
|
|
||||||
client_params = {
|
client_params = {
|
||||||
"api_key": self.api_key,
|
"api_key": self.api_key,
|
||||||
|
|||||||
@@ -107,9 +107,6 @@ class AzureCompletion(BaseLLM):
|
|||||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.api_key is None:
|
|
||||||
self.api_key = os.getenv("AZURE_API_KEY")
|
|
||||||
|
|
||||||
if self.endpoint is None:
|
if self.endpoint is None:
|
||||||
self.endpoint = (
|
self.endpoint = (
|
||||||
os.getenv("AZURE_ENDPOINT")
|
os.getenv("AZURE_ENDPOINT")
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
|
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||||
from typing_extensions import Required, Self
|
from typing_extensions import Required, Self
|
||||||
|
|
||||||
from crewai.events.types.llm_events import LLMCallType
|
from crewai.events.types.llm_events import LLMCallType
|
||||||
@@ -161,10 +161,6 @@ class BedrockCompletion(BaseLLM):
|
|||||||
interceptor: HTTP interceptor (not yet supported for Bedrock)
|
interceptor: HTTP interceptor (not yet supported for Bedrock)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
|
||||||
ignored_types=(property,), arbitrary_types_allowed=True
|
|
||||||
)
|
|
||||||
|
|
||||||
aws_access_key_id: str | None = Field(
|
aws_access_key_id: str | None = Field(
|
||||||
default=None, description="AWS access key (defaults to environment variable)"
|
default=None, description="AWS access key (defaults to environment variable)"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -81,8 +81,6 @@ class OpenAICompletion(BaseLLM):
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def setup_client(self) -> Self:
|
def setup_client(self) -> Self:
|
||||||
"""Initialize OpenAI client after model validation."""
|
"""Initialize OpenAI client after model validation."""
|
||||||
if self.api_key is None:
|
|
||||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
|
||||||
|
|
||||||
client_config = self._get_client_params()
|
client_config = self._get_client_params()
|
||||||
if self.interceptor:
|
if self.interceptor:
|
||||||
@@ -101,9 +99,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
"""Get OpenAI client parameters."""
|
"""Get OpenAI client parameters."""
|
||||||
|
|
||||||
if self.api_key is None:
|
if self.api_key is None:
|
||||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
raise ValueError("OPENAI_API_KEY is required")
|
||||||
if self.api_key is None:
|
|
||||||
raise ValueError("OPENAI_API_KEY is required")
|
|
||||||
|
|
||||||
base_params = {
|
base_params = {
|
||||||
"api_key": self.api_key,
|
"api_key": self.api_key,
|
||||||
|
|||||||
Reference in New Issue
Block a user