From 37bb874edba67f114657ac63a6fecf7503c320bb Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Tue, 3 Dec 2024 15:47:25 -0500 Subject: [PATCH] Incorporate feedback from crewai reviewer --- src/crewai/agent.py | 19 +++++++------------ src/crewai/cli/constants.py | 3 +++ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/crewai/agent.py b/src/crewai/agent.py index f330f06f9..8c79c6eb8 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -8,7 +8,7 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor -from crewai.cli.constants import ENV_VARS +from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context @@ -181,17 +181,12 @@ class Agent(BaseAgent): if key_name and key_name not in unaccepted_attributes: env_value = os.environ.get(key_name) if env_value: - param_name = key_name.lower() - # Map key names containing "API_KEY" to "api_key" - if "api_key" in param_name: - param_name = "api_key" - # Map key names containing "API_BASE" to "api_base" - elif "api_base" in param_name: - param_name = "api_base" - # Map key names containing "API_VERSION" to "api_version" - elif "api_version" in param_name: - param_name = "api_version" - llm_params[param_name] = env_value + key_name = key_name.lower() + for pattern in LITELLM_PARAMS: + if pattern in key_name: + key_name = pattern + break + llm_params[key_name] = env_value # Check for default values if the environment variable is not set elif env_var.get("default", False): for key, value in env_var.items(): diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index e13349155..13279f8d3 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -159,3 +159,6 @@ MODELS = { } JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" + + +LITELLM_PARAMS = ["api_key", "api_base", "api_version"]