Incorporate Stale PRs that have feedback (#1693)

* incorporate #1683

* add in --version flag to cli. closes #1679.

* Fix env issue

* Add in suggestions from @caike to make sure ragstorage doesnt exceed os file limit. Also, included additional checks to support windows.

* remove poetry.lock as pointed out by @sanders41 in #1574.

* Incorporate feedback from crewai reviewer

* Incorporate @lorenzejay feedback
This commit is contained in:
Brandon Hancock (bhancock_ai)
2024-12-05 12:17:23 -05:00
committed by GitHub
parent 3daba0c79e
commit 7b276e6797
6 changed files with 45 additions and 20 deletions

View File

@@ -8,7 +8,7 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.cli.constants import ENV_VARS from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
@@ -181,20 +181,11 @@ class Agent(BaseAgent):
if key_name and key_name not in unaccepted_attributes: if key_name and key_name not in unaccepted_attributes:
env_value = os.environ.get(key_name) env_value = os.environ.get(key_name)
if env_value: if env_value:
# Map key names containing "API_KEY" to "api_key" key_name = key_name.lower()
key_name = ( for pattern in LITELLM_PARAMS:
"api_key" if "API_KEY" in key_name else key_name if pattern in key_name:
) key_name = pattern
# Map key names containing "API_BASE" to "api_base" break
key_name = (
"api_base" if "API_BASE" in key_name else key_name
)
# Map key names containing "API_VERSION" to "api_version"
key_name = (
"api_version"
if "API_VERSION" in key_name
else key_name
)
llm_params[key_name] = env_value llm_params[key_name] = env_value
# Check for default values if the environment variable is not set # Check for default values if the environment variable is not set
elif env_var.get("default", False): elif env_var.get("default", False):

View File

@@ -25,6 +25,7 @@ from .update_crew import update_crew
@click.group() @click.group()
@click.version_option(pkg_resources.get_distribution("crewai").version)
def crewai(): def crewai():
"""Top-level command group for crewai.""" """Top-level command group for crewai."""
@@ -50,7 +51,10 @@ def create(type, name, provider, skip_provider=False):
) )
def version(tools): def version(tools):
"""Show the installed version of crewai.""" """Show the installed version of crewai."""
crewai_version = pkg_resources.get_distribution("crewai").version try:
crewai_version = pkg_resources.get_distribution("crewai").version
except Exception:
crewai_version = "unknown version"
click.echo(f"crewai version: {crewai_version}") click.echo(f"crewai version: {crewai_version}")
if tools: if tools:

View File

@@ -159,3 +159,6 @@ MODELS = {
} }
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]

View File

@@ -4,12 +4,14 @@ import logging
import os import os
import shutil import shutil
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI from chromadb.api import ClientAPI
from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities.paths import db_storage_path
from crewai.utilities import EmbeddingConfigurator from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
@contextlib.contextmanager @contextlib.contextmanager
@@ -37,12 +39,15 @@ class RAGStorage(BaseRAGStorage):
app: ClientAPI | None = None app: ClientAPI | None = None
def __init__(self, type, allow_reset=True, embedder_config=None, crew=None, path=None): def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
):
super().__init__(type, allow_reset, embedder_config, crew) super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else [] agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents] agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents) agents = "_".join(agents)
self.agents = agents self.agents = agents
self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type self.type = type
@@ -60,7 +65,7 @@ class RAGStorage(BaseRAGStorage):
self._set_embedder_config() self._set_embedder_config()
chroma_client = chromadb.PersistentClient( chroma_client = chromadb.PersistentClient(
path=self.path if self.path else f"{db_storage_path()}/{self.type}/{self.agents}", path=self.path if self.path else self.storage_file_name,
settings=Settings(allow_reset=self.allow_reset), settings=Settings(allow_reset=self.allow_reset),
) )
@@ -81,6 +86,20 @@ class RAGStorage(BaseRAGStorage):
""" """
return role.replace("\n", "").replace(" ", "_").replace("/", "_") return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _build_storage_file_name(self, type: str, file_name: str) -> str:
"""
Ensures file name does not exceed max allowed by OS
"""
base_path = f"{db_storage_path()}/{type}"
if len(file_name) > MAX_FILE_NAME_LENGTH:
logging.warning(
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
)
file_name = file_name[:MAX_FILE_NAME_LENGTH]
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: Dict[str, Any]) -> None:
if not hasattr(self, "app") or not hasattr(self, "collection"): if not hasattr(self, "app") or not hasattr(self, "collection"):
self._initialize_app() self._initialize_app()

View File

@@ -3,3 +3,4 @@ TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl"
DEFAULT_SCORE_THRESHOLD = 0.35 DEFAULT_SCORE_THRESHOLD = 0.35
KNOWLEDGE_DIRECTORY = "knowledge" KNOWLEDGE_DIRECTORY = "knowledge"
MAX_LLM_RETRY = 3 MAX_LLM_RETRY = 3
MAX_FILE_NAME_LENGTH = 255

View File

@@ -131,6 +131,13 @@ def test_reset_no_memory_flags(runner):
) )
def test_version_flag(runner):
result = runner.invoke(version)
assert result.exit_code == 0
assert "crewai version:" in result.output
def test_version_command(runner): def test_version_command(runner):
result = runner.invoke(version) result = runner.invoke(version)