From d9b29daa275aafe9aa733be43344750c150d78c6 Mon Sep 17 00:00:00 2001 From: Michael Barnathan Date: Sun, 16 Jun 2024 15:21:31 -0400 Subject: [PATCH] Use CREWAI_PROMPT_FILE environment variable. CrewAI initializes several I18N instances deep within the library at package initialization time, using the default prompts. This makes overriding the defaults consistently throughout the package very difficult, requiring monkeypatching. This change will allow overriding the default prompt file location in the $CREWAI_PROMPT_FILE environment variable, allowing consistency throughout the library. --- src/crewai/utilities/i18n.py | 10 +++++++--- tests/utilities/test_i18n.py | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/crewai/utilities/i18n.py b/src/crewai/utilities/i18n.py index b283f57c0..1e63c8caa 100644 --- a/src/crewai/utilities/i18n.py +++ b/src/crewai/utilities/i18n.py @@ -9,7 +9,9 @@ class I18N(BaseModel): _prompts: Dict[str, Dict[str, str]] = PrivateAttr() prompt_file: Optional[str] = Field( default=None, - description="Path to the prompt_file file to load", + description="Path to the prompt_file file to load. " + "If not provided, $CREWAI_PROMPT_FILE will be checked. " + "If also not set, uses the default prompts" ) @model_validator(mode="after") @@ -20,8 +22,10 @@ class I18N(BaseModel): with open(self.prompt_file, "r") as f: self._prompts = json.load(f) else: - dir_path = os.path.dirname(os.path.realpath(__file__)) - prompts_path = os.path.join(dir_path, "../translations/en.json") + prompts_path = os.environ.get("CREWAI_PROMPT_FILE") + if not prompts_path: + dir_path = os.path.dirname(os.path.realpath(__file__)) + prompts_path = os.path.join(dir_path, "../translations/en.json") with open(prompts_path, "r") as f: self._prompts = json.load(f) diff --git a/tests/utilities/test_i18n.py b/tests/utilities/test_i18n.py index 8627b0bec..e6887f39b 100644 --- a/tests/utilities/test_i18n.py +++ b/tests/utilities/test_i18n.py @@ -42,3 +42,24 @@ def test_prompt_file(): i18n.load_prompts() assert isinstance(i18n.retrieve("slices", "role_playing"), str) assert i18n.retrieve("slices", "role_playing") == "Lorem ipsum dolor sit amet" + + +def test_prompt_file_env(): + import os + + path = os.path.join(os.path.dirname(__file__), "prompts.json") + old_env = os.environ.get("CREWAI_PROMPT_FILE") + try: + os.environ["CREWAI_PROMPT_FILE"] = path + i18n = I18N() + i18n.load_prompts() + assert i18n.retrieve("slices", "role_playing") == "Lorem ipsum dolor sit amet" + finally: + if old_env: + os.environ["CREWAI_PROMPT_FILE"] = old_env + else: + del os.environ["CREWAI_PROMPT_FILE"] + + i18n = I18N() + i18n.load_prompts() + assert i18n.retrieve("slices", "role_playing") != "Lorem ipsum dolor sit amet"