mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
Incorporate Stale PRs that have feedback (#1693)
* incorporate #1683 * add in --version flag to cli. closes #1679. * Fix env issue * Add in suggestions from @caike to make sure ragstorage doesnt exceed os file limit. Also, included additional checks to support windows. * remove poetry.lock as pointed out by @sanders41 in #1574. * Incorporate feedback from crewai reviewer * Incorporate @lorenzejay feedback
This commit is contained in:
committed by
GitHub
parent
3daba0c79e
commit
7b276e6797
@@ -8,7 +8,7 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
|||||||
from crewai.agents import CacheHandler
|
from crewai.agents import CacheHandler
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||||
from crewai.cli.constants import ENV_VARS
|
from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
|
||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||||
@@ -181,20 +181,11 @@ class Agent(BaseAgent):
|
|||||||
if key_name and key_name not in unaccepted_attributes:
|
if key_name and key_name not in unaccepted_attributes:
|
||||||
env_value = os.environ.get(key_name)
|
env_value = os.environ.get(key_name)
|
||||||
if env_value:
|
if env_value:
|
||||||
# Map key names containing "API_KEY" to "api_key"
|
key_name = key_name.lower()
|
||||||
key_name = (
|
for pattern in LITELLM_PARAMS:
|
||||||
"api_key" if "API_KEY" in key_name else key_name
|
if pattern in key_name:
|
||||||
)
|
key_name = pattern
|
||||||
# Map key names containing "API_BASE" to "api_base"
|
break
|
||||||
key_name = (
|
|
||||||
"api_base" if "API_BASE" in key_name else key_name
|
|
||||||
)
|
|
||||||
# Map key names containing "API_VERSION" to "api_version"
|
|
||||||
key_name = (
|
|
||||||
"api_version"
|
|
||||||
if "API_VERSION" in key_name
|
|
||||||
else key_name
|
|
||||||
)
|
|
||||||
llm_params[key_name] = env_value
|
llm_params[key_name] = env_value
|
||||||
# Check for default values if the environment variable is not set
|
# Check for default values if the environment variable is not set
|
||||||
elif env_var.get("default", False):
|
elif env_var.get("default", False):
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from .update_crew import update_crew
|
|||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
|
@click.version_option(pkg_resources.get_distribution("crewai").version)
|
||||||
def crewai():
|
def crewai():
|
||||||
"""Top-level command group for crewai."""
|
"""Top-level command group for crewai."""
|
||||||
|
|
||||||
@@ -50,7 +51,10 @@ def create(type, name, provider, skip_provider=False):
|
|||||||
)
|
)
|
||||||
def version(tools):
|
def version(tools):
|
||||||
"""Show the installed version of crewai."""
|
"""Show the installed version of crewai."""
|
||||||
crewai_version = pkg_resources.get_distribution("crewai").version
|
try:
|
||||||
|
crewai_version = pkg_resources.get_distribution("crewai").version
|
||||||
|
except Exception:
|
||||||
|
crewai_version = "unknown version"
|
||||||
click.echo(f"crewai version: {crewai_version}")
|
click.echo(f"crewai version: {crewai_version}")
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
|
|||||||
@@ -159,3 +159,6 @@ MODELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||||
|
|
||||||
|
|
||||||
|
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
|
|
||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.utilities.paths import db_storage_path
|
|
||||||
from crewai.utilities import EmbeddingConfigurator
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
|
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||||
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@@ -37,12 +39,15 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
|
|
||||||
app: ClientAPI | None = None
|
app: ClientAPI | None = None
|
||||||
|
|
||||||
def __init__(self, type, allow_reset=True, embedder_config=None, crew=None, path=None):
|
def __init__(
|
||||||
|
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||||
|
):
|
||||||
super().__init__(type, allow_reset, embedder_config, crew)
|
super().__init__(type, allow_reset, embedder_config, crew)
|
||||||
agents = crew.agents if crew else []
|
agents = crew.agents if crew else []
|
||||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||||
agents = "_".join(agents)
|
agents = "_".join(agents)
|
||||||
self.agents = agents
|
self.agents = agents
|
||||||
|
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||||
|
|
||||||
self.type = type
|
self.type = type
|
||||||
|
|
||||||
@@ -60,7 +65,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
|
|
||||||
self._set_embedder_config()
|
self._set_embedder_config()
|
||||||
chroma_client = chromadb.PersistentClient(
|
chroma_client = chromadb.PersistentClient(
|
||||||
path=self.path if self.path else f"{db_storage_path()}/{self.type}/{self.agents}",
|
path=self.path if self.path else self.storage_file_name,
|
||||||
settings=Settings(allow_reset=self.allow_reset),
|
settings=Settings(allow_reset=self.allow_reset),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,6 +86,20 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
"""
|
"""
|
||||||
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
|
||||||
|
|
||||||
|
def _build_storage_file_name(self, type: str, file_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Ensures file name does not exceed max allowed by OS
|
||||||
|
"""
|
||||||
|
base_path = f"{db_storage_path()}/{type}"
|
||||||
|
|
||||||
|
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||||
|
logging.warning(
|
||||||
|
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
|
||||||
|
)
|
||||||
|
file_name = file_name[:MAX_FILE_NAME_LENGTH]
|
||||||
|
|
||||||
|
return f"{base_path}/{file_name}"
|
||||||
|
|
||||||
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||||
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
if not hasattr(self, "app") or not hasattr(self, "collection"):
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|||||||
@@ -3,3 +3,4 @@ TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl"
|
|||||||
DEFAULT_SCORE_THRESHOLD = 0.35
|
DEFAULT_SCORE_THRESHOLD = 0.35
|
||||||
KNOWLEDGE_DIRECTORY = "knowledge"
|
KNOWLEDGE_DIRECTORY = "knowledge"
|
||||||
MAX_LLM_RETRY = 3
|
MAX_LLM_RETRY = 3
|
||||||
|
MAX_FILE_NAME_LENGTH = 255
|
||||||
|
|||||||
@@ -131,6 +131,13 @@ def test_reset_no_memory_flags(runner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_version_flag(runner):
|
||||||
|
result = runner.invoke(version)
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "crewai version:" in result.output
|
||||||
|
|
||||||
|
|
||||||
def test_version_command(runner):
|
def test_version_command(runner):
|
||||||
result = runner.invoke(version)
|
result = runner.invoke(version)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user