Merge branch 'main' of github.com:crewAIInc/crewAI into add/agent-specific-knowledge

This commit is contained in:
Lorenze Jay
2024-11-26 11:57:15 -08:00
10 changed files with 122 additions and 31 deletions

View File

@@ -13,10 +13,12 @@ from crewai.llm import LLM
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.task import Task
from crewai.tools import BaseTool
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -135,7 +137,7 @@ class Agent(BaseAgent):
def post_init_setup(self):
self._set_knowledge()
self.agent_ops_agent_name = self.role
unnacepted_attributes = [
unaccepted_attributes = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
@@ -169,28 +171,23 @@ class Agent(BaseAgent):
for provider, env_vars in ENV_VARS.items():
if provider == set_provider:
for env_var in env_vars:
if env_var["key_name"] in unnacepted_attributes:
continue
# Check if the environment variable is set
if "key_name" in env_var:
env_value = os.environ.get(env_var["key_name"])
key_name = env_var.get("key_name")
if key_name and key_name not in unaccepted_attributes:
env_value = os.environ.get(key_name)
if env_value:
# Map key names containing "API_KEY" to "api_key"
key_name = (
"api_key"
if "API_KEY" in env_var["key_name"]
else env_var["key_name"]
"api_key" if "API_KEY" in key_name else key_name
)
# Map key names containing "API_BASE" to "api_base"
key_name = (
"api_base"
if "API_BASE" in env_var["key_name"]
else key_name
"api_base" if "API_BASE" in key_name else key_name
)
# Map key names containing "API_VERSION" to "api_version"
key_name = (
"api_version"
if "API_VERSION" in env_var["key_name"]
if "API_VERSION" in key_name
else key_name
)
llm_params[key_name] = env_value
@@ -264,7 +261,7 @@ class Agent(BaseAgent):
def execute_task(
self,
task: Any,
task: Task,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
) -> str:
@@ -283,6 +280,22 @@ class Agent(BaseAgent):
task_prompt = task.prompt()
# If the task requires output in JSON or Pydantic format,
# append specific instructions to the task prompt to ensure
# that the final answer does not include any code block markers
if task.output_json or task.output_pydantic:
# Generate the schema based on the output format
if task.output_json:
# schema = json.dumps(task.output_json, indent=2)
schema = generate_model_description(task.output_json)
elif task.output_pydantic:
schema = generate_model_description(task.output_pydantic)
task_prompt += "\n" + self.i18n.slice("formatted_task_instructions").format(
output_format=schema
)
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context