mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
preparing new version
This commit is contained in:
@@ -123,6 +123,11 @@ class Agent(BaseAgent):
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
self.agent_ops_agent_name = self.role
|
||||
unnacepted_attributes = [
|
||||
"AWS_ACCESS_KEY_ID",
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
"AWS_REGION_NAME",
|
||||
]
|
||||
|
||||
# Handle different cases for self.llm
|
||||
if isinstance(self.llm, str):
|
||||
@@ -146,39 +151,49 @@ class Agent(BaseAgent):
|
||||
if api_base:
|
||||
llm_params["base_url"] = api_base
|
||||
|
||||
set_provider = model_name.split("/")[0] if "/" in model_name else "openai"
|
||||
|
||||
# Iterate over all environment variables to find matching API keys or use defaults
|
||||
for provider, env_vars in ENV_VARS.items():
|
||||
for env_var in env_vars:
|
||||
# Check if the environment variable is set
|
||||
if "key_name" in env_var:
|
||||
env_value = os.environ.get(env_var["key_name"])
|
||||
if env_value:
|
||||
# Map key names containing "API_KEY" to "api_key"
|
||||
key_name = (
|
||||
"api_key"
|
||||
if "API_KEY" in env_var["key_name"]
|
||||
else env_var["key_name"]
|
||||
if provider == set_provider:
|
||||
for env_var in env_vars:
|
||||
if env_var["key_name"] in unnacepted_attributes:
|
||||
continue
|
||||
# Check if the environment variable is set
|
||||
if "key_name" in env_var:
|
||||
env_value = os.environ.get(env_var["key_name"])
|
||||
print(
|
||||
f"Checking env var {env_var['key_name']}: {env_value}"
|
||||
)
|
||||
# Map key names containing "API_BASE" to "api_base"
|
||||
key_name = (
|
||||
"api_base"
|
||||
if "API_BASE" in env_var["key_name"]
|
||||
else key_name
|
||||
)
|
||||
# Map key names containing "API_VERSION" to "api_version"
|
||||
key_name = (
|
||||
"api_version"
|
||||
if "API_VERSION" in env_var["key_name"]
|
||||
else key_name
|
||||
)
|
||||
llm_params[key_name] = env_value
|
||||
# Check for default values if the environment variable is not set
|
||||
elif env_var.get("default", False):
|
||||
for key, value in env_var.items():
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
# Only add default if the key is already set in os.environ
|
||||
if key in os.environ:
|
||||
llm_params[key] = value
|
||||
if env_value:
|
||||
# Map key names containing "API_KEY" to "api_key"
|
||||
key_name = (
|
||||
"api_key"
|
||||
if "API_KEY" in env_var["key_name"]
|
||||
else env_var["key_name"]
|
||||
)
|
||||
# Map key names containing "API_BASE" to "api_base"
|
||||
key_name = (
|
||||
"api_base"
|
||||
if "API_BASE" in env_var["key_name"]
|
||||
else key_name
|
||||
)
|
||||
# Map key names containing "API_VERSION" to "api_version"
|
||||
key_name = (
|
||||
"api_version"
|
||||
if "API_VERSION" in env_var["key_name"]
|
||||
else key_name
|
||||
)
|
||||
print(f"Mapped key name: {key_name}")
|
||||
llm_params[key_name] = env_value
|
||||
# Check for default values if the environment variable is not set
|
||||
elif env_var.get("default", False):
|
||||
for key, value in env_var.items():
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
# Only add default if the key is already set in os.environ
|
||||
if key in os.environ:
|
||||
print(f"Using default value for {key}: {value}")
|
||||
llm_params[key] = value
|
||||
|
||||
self.llm = LLM(**llm_params)
|
||||
else:
|
||||
|
||||
@@ -332,9 +332,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
|
||||
train_iteration = self.crew._train_iteration
|
||||
if agent_id in training_data and isinstance(train_iteration, int):
|
||||
training_data[agent_id][train_iteration][
|
||||
"improved_output"
|
||||
] = result.output
|
||||
training_data[agent_id][train_iteration]["improved_output"] = (
|
||||
result.output
|
||||
)
|
||||
training_handler.save(training_data)
|
||||
else:
|
||||
self._logger.log(
|
||||
@@ -385,4 +385,5 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return CrewAgentParser(agent=self.agent).parse(answer)
|
||||
|
||||
def _format_msg(self, prompt: str, role: str = "user") -> Dict[str, str]:
|
||||
prompt = prompt.rstrip()
|
||||
return {"role": role, "content": prompt}
|
||||
|
||||
Reference in New Issue
Block a user