mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-06 06:38:29 +00:00
Compare commits
2 Commits
devin/1765
...
devin/1758
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b046c4881 | ||
|
|
b14bcd01c5 |
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import AliasChoices, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
@@ -21,7 +21,8 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
||||
default=WatsonEmbeddingFunction, description="Watson embedding function class"
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="Watson model ID", validation_alias="WATSON_MODEL_ID"
|
||||
description="Watson model ID",
|
||||
validation_alias=AliasChoices("WATSONX_MODEL_ID", "WATSON_MODEL_ID"),
|
||||
)
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
@@ -30,7 +31,7 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="Watson project ID",
|
||||
validation_alias="WATSON_PROJECT_ID",
|
||||
validation_alias=AliasChoices("WATSONX_PROJECT_ID", "WATSON_PROJECT_ID"),
|
||||
)
|
||||
space_id: str | None = Field(
|
||||
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
|
||||
@@ -67,9 +68,13 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
||||
retry_status_codes: list[int] | None = Field(
|
||||
default=None, description="HTTP status codes to retry on"
|
||||
)
|
||||
url: str = Field(description="Watson API URL", validation_alias="WATSON_URL")
|
||||
url: str = Field(
|
||||
description="Watson API URL",
|
||||
validation_alias=AliasChoices("WATSONX_URL", "WATSON_URL"),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Watson API key", validation_alias="WATSON_API_KEY"
|
||||
description="Watson API key",
|
||||
validation_alias=AliasChoices("WATSONX_APIKEY", "WATSON_API_KEY"),
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None, description="Service name", validation_alias="WATSON_NAME"
|
||||
@@ -85,7 +90,9 @@ class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
|
||||
validation_alias="WATSON_TRUSTED_PROFILE_ID",
|
||||
)
|
||||
token: str | None = Field(
|
||||
default=None, description="Bearer token", validation_alias="WATSON_TOKEN"
|
||||
default=None,
|
||||
description="Bearer token",
|
||||
validation_alias=AliasChoices("WATSONX_TOKEN", "WATSON_TOKEN"),
|
||||
)
|
||||
projects_token: str | None = Field(
|
||||
default=None,
|
||||
|
||||
140
tests/rag/embeddings/providers/ibm/test_watson_env_vars.py
Normal file
140
tests/rag/embeddings/providers/ibm/test_watson_env_vars.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Tests for Watson provider environment variable handling."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.providers.ibm.watson import WatsonProvider
|
||||
|
||||
|
||||
class TestWatsonEnvironmentVariables:
|
||||
"""Test Watson provider environment variable compatibility."""
|
||||
|
||||
def test_watsonx_prefix_variables(self):
|
||||
"""Test that WATSONX_ prefixed variables work correctly."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WATSONX_URL": "https://us-south.ml.cloud.ibm.com",
|
||||
"WATSONX_APIKEY": "test-api-key",
|
||||
"WATSONX_PROJECT_ID": "test-project-id",
|
||||
"WATSONX_MODEL_ID": "ibm/slate-125m-english-rtrvr",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
provider = WatsonProvider()
|
||||
assert provider.url == "https://us-south.ml.cloud.ibm.com"
|
||||
assert provider.api_key == "test-api-key"
|
||||
assert provider.project_id == "test-project-id"
|
||||
assert provider.model_id == "ibm/slate-125m-english-rtrvr"
|
||||
|
||||
def test_watson_prefix_backward_compatibility(self):
|
||||
"""Test that legacy WATSON_ prefixed variables still work."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WATSON_URL": "https://us-south.ml.cloud.ibm.com",
|
||||
"WATSON_API_KEY": "test-api-key",
|
||||
"WATSON_PROJECT_ID": "test-project-id",
|
||||
"WATSON_MODEL_ID": "ibm/slate-125m-english-rtrvr",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
provider = WatsonProvider()
|
||||
assert provider.url == "https://us-south.ml.cloud.ibm.com"
|
||||
assert provider.api_key == "test-api-key"
|
||||
assert provider.project_id == "test-project-id"
|
||||
assert provider.model_id == "ibm/slate-125m-english-rtrvr"
|
||||
|
||||
def test_watsonx_takes_precedence_over_watson(self):
|
||||
"""Test that WATSONX_ variables take precedence over WATSON_ when both are set."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WATSONX_URL": "https://new-url.com",
|
||||
"WATSON_URL": "https://old-url.com",
|
||||
"WATSONX_APIKEY": "new-key",
|
||||
"WATSON_API_KEY": "old-key",
|
||||
"WATSONX_PROJECT_ID": "new-project",
|
||||
"WATSON_PROJECT_ID": "old-project",
|
||||
"WATSONX_MODEL_ID": "new-model",
|
||||
"WATSON_MODEL_ID": "old-model",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
provider = WatsonProvider()
|
||||
assert provider.url == "https://new-url.com"
|
||||
assert provider.api_key == "new-key"
|
||||
assert provider.project_id == "new-project"
|
||||
assert provider.model_id == "new-model"
|
||||
|
||||
def test_mixed_environment_variables(self):
|
||||
"""Test that mixing WATSONX_ and WATSON_ variables works correctly."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WATSONX_URL": "https://us-south.ml.cloud.ibm.com",
|
||||
"WATSON_API_KEY": "test-api-key",
|
||||
"WATSONX_PROJECT_ID": "test-project-id",
|
||||
"WATSON_MODEL_ID": "ibm/slate-125m-english-rtrvr",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
provider = WatsonProvider()
|
||||
assert provider.url == "https://us-south.ml.cloud.ibm.com"
|
||||
assert provider.api_key == "test-api-key"
|
||||
assert provider.project_id == "test-project-id"
|
||||
assert provider.model_id == "ibm/slate-125m-english-rtrvr"
|
||||
|
||||
def test_token_environment_variables(self):
|
||||
"""Test that token environment variables work with both prefixes."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WATSONX_URL": "https://us-south.ml.cloud.ibm.com",
|
||||
"WATSONX_APIKEY": "test-api-key",
|
||||
"WATSONX_PROJECT_ID": "test-project-id",
|
||||
"WATSONX_MODEL_ID": "ibm/slate-125m-english-rtrvr",
|
||||
"WATSONX_TOKEN": "test-token",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
provider = WatsonProvider()
|
||||
assert provider.token == "test-token" # noqa: S105
|
||||
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WATSON_URL": "https://us-south.ml.cloud.ibm.com",
|
||||
"WATSON_API_KEY": "test-api-key",
|
||||
"WATSON_PROJECT_ID": "test-project-id",
|
||||
"WATSON_MODEL_ID": "ibm/slate-125m-english-rtrvr",
|
||||
"WATSON_TOKEN": "legacy-token",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
provider = WatsonProvider()
|
||||
assert provider.token == "legacy-token" # noqa: S105
|
||||
|
||||
def test_validation_error_when_required_fields_missing(self):
|
||||
"""Test that validation errors are raised when required fields are missing."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(ValueError):
|
||||
WatsonProvider()
|
||||
|
||||
def test_space_or_project_validation(self):
|
||||
"""Test that either space_id or project_id must be provided."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"WATSONX_URL": "https://us-south.ml.cloud.ibm.com",
|
||||
"WATSONX_APIKEY": "test-api-key",
|
||||
"WATSONX_MODEL_ID": "ibm/slate-125m-english-rtrvr",
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError, match="One of 'space_id' or 'project_id' must be provided"
|
||||
):
|
||||
WatsonProvider()
|
||||
Reference in New Issue
Block a user