mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff16348d4c | ||
|
|
7310f4d85b | ||
|
|
ac331504e9 | ||
|
|
6823f76ff4 | ||
|
|
c3ac3219fe | ||
|
|
104ef7a0c2 | ||
|
|
2bbf8ed8a8 | ||
|
|
5dc6644ac7 | ||
|
|
9c0f97eaf7 | ||
|
|
164e7895bf | ||
|
|
fb46fb9ca3 | ||
|
|
effb7efc37 | ||
|
|
f5098e7e45 | ||
|
|
b15d632308 |
@@ -162,7 +162,7 @@ nav:
|
||||
- Directory RAG Search: 'tools/DirectorySearchTool.md'
|
||||
- Directory Read: 'tools/DirectoryReadTool.md'
|
||||
- Docx Rag Search: 'tools/DOCXSearchTool.md'
|
||||
- EXA Serch Web Loader: 'tools/EXASearchTool.md'
|
||||
- EXA Search Web Loader: 'tools/EXASearchTool.md'
|
||||
- File Read: 'tools/FileReadTool.md'
|
||||
- File Write: 'tools/FileWriteTool.md'
|
||||
- Firecrawl Crawl Website Tool: 'tools/FirecrawlCrawlWebsiteTool.md'
|
||||
@@ -210,6 +210,6 @@ extra:
|
||||
property: G-N3Q505TMQ6
|
||||
social:
|
||||
- icon: fontawesome/brands/twitter
|
||||
link: https://twitter.com/joaomdmoura
|
||||
link: https://x.com/crewAIInc
|
||||
- icon: fontawesome/brands/github
|
||||
link: https://github.com/joaomdmoura/crewAI
|
||||
link: https://github.com/crewAIInc/crewAI
|
||||
|
||||
60
poetry.lock
generated
60
poetry.lock
generated
@@ -390,17 +390,17 @@ lxml = ["lxml"]
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.35.26"
|
||||
version = "1.35.27"
|
||||
description = "The AWS SDK for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "boto3-1.35.26-py3-none-any.whl", hash = "sha256:c31db992655db233d98762612690cfe60723c9e1503b5709aad92c1c564877bb"},
|
||||
{file = "boto3-1.35.26.tar.gz", hash = "sha256:b04087afd3570ba540fd293823c77270ec675672af23da9396bd5988a3f8128b"},
|
||||
{file = "boto3-1.35.27-py3-none-any.whl", hash = "sha256:3da139ca038032e92086e26d23833b557f0c257520162bfd3d6f580bf8032c86"},
|
||||
{file = "boto3-1.35.27.tar.gz", hash = "sha256:10d0fe15670b83a3f26572ab20d9152a064cee4c54b5ea9a1eeb1f0c3b807a7b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
botocore = ">=1.35.26,<1.36.0"
|
||||
botocore = ">=1.35.27,<1.36.0"
|
||||
jmespath = ">=0.7.1,<2.0.0"
|
||||
s3transfer = ">=0.10.0,<0.11.0"
|
||||
|
||||
@@ -409,13 +409,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.35.26"
|
||||
version = "1.35.27"
|
||||
description = "Low-level, data-driven core of boto 3."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "botocore-1.35.26-py3-none-any.whl", hash = "sha256:0b9dee5e4a3314e251e103585837506b17fcc7485c3c8adb61a9a913f46da1e7"},
|
||||
{file = "botocore-1.35.26.tar.gz", hash = "sha256:19efc3a22c9df77960712b4e203f912486f8bcd3794bff0fd7b2a0f5f1d5712d"},
|
||||
{file = "botocore-1.35.27-py3-none-any.whl", hash = "sha256:c299c70b5330a8634e032883ce8a72c2c6d9fdbc985d8191199cb86b92e7cbbd"},
|
||||
{file = "botocore-1.35.27.tar.gz", hash = "sha256:f68875c26cd57a9d22c0f7a981ecb1636d7ce4d0e35797e04765b53e7bfed3e7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -833,13 +833,13 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
|
||||
[[package]]
|
||||
name = "cohere"
|
||||
version = "5.9.4"
|
||||
version = "5.10.0"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
files = [
|
||||
{file = "cohere-5.9.4-py3-none-any.whl", hash = "sha256:d1b31d8ba32e338b3aa91737aa98dc74de8778ed8e397ab799739b5f060f44e7"},
|
||||
{file = "cohere-5.9.4.tar.gz", hash = "sha256:ed0fa256c51423175c208650dffcb534ae112dc3ab7703de352e2adaf99dd50b"},
|
||||
{file = "cohere-5.10.0-py3-none-any.whl", hash = "sha256:46e50e3e8514a99cf77b4c022c8077a6205fba948051c33087ddeb66ec706f0a"},
|
||||
{file = "cohere-5.10.0.tar.gz", hash = "sha256:21020a7ae4c30f72991ef91566a926a9d7d1485d7abeed7bfa2bd6f35ea34783"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2864,13 +2864,13 @@ requests = ">=2,<3"
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.48.0"
|
||||
version = "1.48.2"
|
||||
description = "Library to easily interface with LLM API providers"
|
||||
optional = false
|
||||
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
|
||||
files = [
|
||||
{file = "litellm-1.48.0-py3-none-any.whl", hash = "sha256:7765e8a92069778f5fc66aacfabd0e2f8ec8d74fb117f5e475567d89b0d376b9"},
|
||||
{file = "litellm-1.48.0.tar.gz", hash = "sha256:31a9b8a25a9daf44c24ddc08bf74298da920f2c5cea44135e5061278d0aa6fc9"},
|
||||
{file = "litellm-1.48.2-py3-none-any.whl", hash = "sha256:4819f126be8e3cabfa9b1aadd44459442828feb0ca17538189cff48ac32939d9"},
|
||||
{file = "litellm-1.48.2.tar.gz", hash = "sha256:1c01dd5035936b8a6813e2c26f029fc227f1644eb842ceb8f787deae0bfae57e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3049,13 +3049,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "mem0ai"
|
||||
version = "0.1.15"
|
||||
version = "0.1.16"
|
||||
description = "Long-term memory for AI Agents"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "mem0ai-0.1.15-py3-none-any.whl", hash = "sha256:b553a8a3491565f9e52de535282304218a3773069cb4d41d2da6b4236a7b06bc"},
|
||||
{file = "mem0ai-0.1.15.tar.gz", hash = "sha256:8f0d6207da36f9ffcb3b1c5ee366f1d0f550daf577d875abc96477309ef233c8"},
|
||||
{file = "mem0ai-0.1.16-py3-none-any.whl", hash = "sha256:fa302457668a9a8994f0865741517f6d88b844f14dd3e8f6c17a30a3eed343a9"},
|
||||
{file = "mem0ai-0.1.16.tar.gz", hash = "sha256:2b4f1f94fc78796b483f6440230a9a68c9815068e16b5092c2b9758635511a2c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3144,13 +3144,13 @@ pyyaml = ">=5.1"
|
||||
|
||||
[[package]]
|
||||
name = "mkdocs-material"
|
||||
version = "9.5.36"
|
||||
version = "9.5.37"
|
||||
description = "Documentation that simply works"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "mkdocs_material-9.5.36-py3-none-any.whl", hash = "sha256:36734c1fd9404bea74236242ba3359b267fc930c7233b9fd086b0898825d0ac9"},
|
||||
{file = "mkdocs_material-9.5.36.tar.gz", hash = "sha256:140456f761320f72b399effc073fa3f8aac744c77b0970797c201cae2f6c967f"},
|
||||
{file = "mkdocs_material-9.5.37-py3-none-any.whl", hash = "sha256:6e8a986abad77be5edec3dd77cf1ddf2480963fb297a8e971f87a82fd464b070"},
|
||||
{file = "mkdocs_material-9.5.37.tar.gz", hash = "sha256:2c31607431ec234db124031255b0a9d4f3e1c3ecc2c47ad97ecfff0460471941"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3745,13 +3745,13 @@ sympy = "*"
|
||||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.47.1"
|
||||
version = "1.48.0"
|
||||
description = "The official Python library for the openai API"
|
||||
optional = false
|
||||
python-versions = ">=3.7.1"
|
||||
files = [
|
||||
{file = "openai-1.47.1-py3-none-any.whl", hash = "sha256:34277583bf268bb2494bc03f48ac123788c5e2a914db1d5a23d5edc29d35c825"},
|
||||
{file = "openai-1.47.1.tar.gz", hash = "sha256:62c8f5f478f82ffafc93b33040f8bb16a45948306198bd0cba2da2ecd9cf7323"},
|
||||
{file = "openai-1.48.0-py3-none-any.whl", hash = "sha256:7c4af223f0bf615ce4a12453729952c9a8b04ffe8c78aa77981b12fd970149cf"},
|
||||
{file = "openai-1.48.0.tar.gz", hash = "sha256:1d3b69ea62c287c4885a6f3ce840768564cd5f52c60ac5f890fef80d43cc4799"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4869,13 +4869,13 @@ torch = ["torch"]
|
||||
|
||||
[[package]]
|
||||
name = "pymdown-extensions"
|
||||
version = "10.10.1"
|
||||
version = "10.10.2"
|
||||
description = "Extension pack for Python Markdown."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pymdown_extensions-10.10.1-py3-none-any.whl", hash = "sha256:6c74ea6c2e2285186a241417480fc2d3cc52941b3ec2dced4014c84dc78c5493"},
|
||||
{file = "pymdown_extensions-10.10.1.tar.gz", hash = "sha256:ad277ee4739ced051c3b6328d22ce782358a3bec39bc6ca52815ccbf44f7acdc"},
|
||||
{file = "pymdown_extensions-10.10.2-py3-none-any.whl", hash = "sha256:513a9e9432b197cf0539356c8f1fc376e0d10b70ad150cadeb649a5628aacd45"},
|
||||
{file = "pymdown_extensions-10.10.2.tar.gz", hash = "sha256:65d82324ef2497931bc858c8320540c6264ab0d9a292707edb61f4fe0cd56633"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4943,21 +4943,23 @@ dev = ["build", "flake8", "mypy", "pytest", "twine"]
|
||||
|
||||
[[package]]
|
||||
name = "pyright"
|
||||
version = "1.1.381"
|
||||
version = "1.1.382.post0"
|
||||
description = "Command line wrapper for pyright"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pyright-1.1.381-py3-none-any.whl", hash = "sha256:5dc0aa80a265675d36abab59c674ae01dbe476714f91845b61b841d34aa99081"},
|
||||
{file = "pyright-1.1.381.tar.gz", hash = "sha256:314cf0c1351c189524fb10c7ac20688ecd470e8cc505c394d642c9c80bf7c3a5"},
|
||||
{file = "pyright-1.1.382.post0-py3-none-any.whl", hash = "sha256:a82a20b6a6511d71c6c95de19c0f874f7e50a013f332e3799deaae66a4d237d1"},
|
||||
{file = "pyright-1.1.382.post0.tar.gz", hash = "sha256:4b84dd4439b0cbc662dff6aaf012cc0860f1c788932ac4c2a4b5d6c1280a5e20"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
nodeenv = ">=1.6.0"
|
||||
typing-extensions = ">=4.1"
|
||||
|
||||
[package.extras]
|
||||
all = ["twine (>=3.4.1)"]
|
||||
all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"]
|
||||
dev = ["twine (>=3.4.1)"]
|
||||
nodejs = ["nodejs-wheel-binaries"]
|
||||
|
||||
[[package]]
|
||||
name = "pysbd"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "crewai"
|
||||
version = "0.63.6"
|
||||
version = "0.64.0"
|
||||
description = "Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks."
|
||||
authors = ["Joao Moura <joao@crewai.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
@@ -104,7 +104,7 @@ class Agent(BaseAgent):
|
||||
description="Keep messages under the context window size by summarizing content.",
|
||||
)
|
||||
max_iter: int = Field(
|
||||
default=15,
|
||||
default=20,
|
||||
description="Maximum number of iterations for an agent to execute a task before giving it's best answer",
|
||||
)
|
||||
max_retry_limit: int = Field(
|
||||
@@ -344,8 +344,9 @@ class Agent(BaseAgent):
|
||||
human_feedbacks = [
|
||||
i["human_feedback"] for i in data.get(agent_id, {}).values()
|
||||
]
|
||||
task_prompt += "You MUST follow these feedbacks: \n " + "\n - ".join(
|
||||
human_feedbacks
|
||||
task_prompt += (
|
||||
"\n\nYou MUST follow these instructions: \n "
|
||||
+ "\n - ".join(human_feedbacks)
|
||||
)
|
||||
|
||||
return task_prompt
|
||||
@@ -354,8 +355,9 @@ class Agent(BaseAgent):
|
||||
"""Use trained data for the agent task prompt to improve output."""
|
||||
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
|
||||
if trained_data_output := data.get(self.role):
|
||||
task_prompt += "You MUST follow these feedbacks: \n " + "\n - ".join(
|
||||
trained_data_output["suggestions"]
|
||||
task_prompt += (
|
||||
"\n\nYou MUST follow these instructions: \n - "
|
||||
+ "\n - ".join(trained_data_output["suggestions"])
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
|
||||
@@ -176,7 +176,11 @@ class BaseAgent(ABC, BaseModel):
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
source = [self.role, self.goal, self.backstory]
|
||||
source = [
|
||||
self._original_role or self.role,
|
||||
self._original_goal or self.goal,
|
||||
self._original_backstory or self.backstory,
|
||||
]
|
||||
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -6,6 +6,7 @@ from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.utilities.converter import ConverterError
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -22,6 +23,7 @@ class CrewAgentExecutorMixin:
|
||||
have_forced_answer: bool
|
||||
max_iter: int
|
||||
_i18n: I18N
|
||||
_printer: Printer = Printer()
|
||||
|
||||
def _should_force_answer(self) -> bool:
|
||||
"""Determine if a forced answer is required based on iteration count."""
|
||||
@@ -100,6 +102,12 @@ class CrewAgentExecutorMixin:
|
||||
|
||||
def _ask_human_input(self, final_answer: dict) -> str:
|
||||
"""Prompt human input for final decision making."""
|
||||
return input(
|
||||
self._i18n.slice("getting_input").format(final_answer=final_answer)
|
||||
self._printer.print(
|
||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
||||
)
|
||||
|
||||
self._printer.print(
|
||||
content="\n\n=====\n## Please provide feedback on the Final Result and the Agent's actions:",
|
||||
color="bold_yellow",
|
||||
)
|
||||
return input()
|
||||
|
||||
@@ -67,8 +67,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.ask_for_human_input = False
|
||||
self.messages: List[Dict[str, str]] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.have_forced_answer = False
|
||||
self.name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
self.llm.stop = self.stop
|
||||
|
||||
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
if "system" in self.prompt:
|
||||
@@ -96,6 +98,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.messages.append(self._format_msg(f"Feedback: {human_feedback}"))
|
||||
formatted_answer = self._invoke_loop()
|
||||
|
||||
if self.crew and self.crew._train:
|
||||
self._handle_crew_training_output(formatted_answer)
|
||||
|
||||
return {"output": formatted_answer.output}
|
||||
|
||||
def _invoke_loop(self, formatted_answer=None):
|
||||
@@ -148,6 +153,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
except OutputParserException as e:
|
||||
self.messages.append({"role": "user", "content": e.error})
|
||||
if self.iterations > self.log_error_after:
|
||||
self._printer.print(
|
||||
content=f"Error parsing LLM output, agent will retry: {e.error}",
|
||||
color="red",
|
||||
)
|
||||
return self._invoke_loop(formatted_answer)
|
||||
|
||||
except Exception as e:
|
||||
@@ -242,21 +252,21 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
def _summarize_messages(self) -> None:
|
||||
messages_groups = []
|
||||
|
||||
for message in self.messages:
|
||||
content = message["content"]
|
||||
for i in range(0, len(content), 5000):
|
||||
messages_groups.append(content[i : i + 5000])
|
||||
cut_size = self.llm.get_context_window_size()
|
||||
for i in range(0, len(content), cut_size):
|
||||
messages_groups.append(content[i : i + cut_size])
|
||||
|
||||
summarized_contents = []
|
||||
for group in messages_groups:
|
||||
summary = self.llm.call(
|
||||
[
|
||||
self._format_msg(
|
||||
self._i18n.slices("summarizer_system_message"), role="system"
|
||||
self._i18n.slice("summarizer_system_message"), role="system"
|
||||
),
|
||||
self._format_msg(
|
||||
self._i18n.errors("sumamrize_instruction").format(group=group),
|
||||
self._i18n.slice("sumamrize_instruction").format(group=group),
|
||||
),
|
||||
],
|
||||
callbacks=self.callbacks,
|
||||
@@ -267,7 +277,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
self.messages = [
|
||||
self._format_msg(
|
||||
self._i18n.errors("summary").format(merged_summary=merged_summary)
|
||||
self._i18n.slice("summary").format(merged_summary=merged_summary)
|
||||
)
|
||||
]
|
||||
|
||||
@@ -294,24 +304,16 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
) -> None:
|
||||
"""Function to handle the process of the training data."""
|
||||
agent_id = str(self.agent.id)
|
||||
|
||||
if (
|
||||
CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
||||
and not self.ask_for_human_input
|
||||
):
|
||||
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
||||
if training_data.get(agent_id):
|
||||
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
|
||||
training_data[agent_id][self.crew._train_iteration][
|
||||
"improved_output"
|
||||
] = result.output
|
||||
CrewTrainingHandler(TRAINING_DATA_FILE).save(training_data)
|
||||
else:
|
||||
self._logger.log(
|
||||
"error",
|
||||
"Invalid crew or missing _train_iteration attribute.",
|
||||
color="red",
|
||||
)
|
||||
training_data[agent_id][self.crew._train_iteration][
|
||||
"improved_output"
|
||||
] = result.output
|
||||
CrewTrainingHandler(TRAINING_DATA_FILE).save(training_data)
|
||||
|
||||
if self.ask_for_human_input and human_feedback is not None:
|
||||
training_data = {
|
||||
|
||||
@@ -11,6 +11,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
|
||||
from .authentication.main import AuthenticationCommand
|
||||
from .deploy.main import DeployCommand
|
||||
from .tools.main import ToolCommand
|
||||
from .evaluate_crew import evaluate_crew
|
||||
from .install_crew import install_crew
|
||||
from .replay_from_task import replay_task_command
|
||||
@@ -202,6 +203,12 @@ def deploy():
|
||||
pass
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def tool():
|
||||
"""Tool Repository related commands."""
|
||||
pass
|
||||
|
||||
|
||||
@deploy.command(name="create")
|
||||
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
|
||||
def deploy_create(yes: bool):
|
||||
@@ -249,5 +256,20 @@ def deploy_remove(uuid: Optional[str]):
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
|
||||
|
||||
@tool.command(name="install")
|
||||
@click.argument("handle")
|
||||
def tool_install(handle: str):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.install(handle)
|
||||
|
||||
|
||||
@tool.command(name="publish")
|
||||
@click.option("--public", "is_public", flag_value=True, default=False)
|
||||
@click.option("--private", "is_public", flag_value=False)
|
||||
def tool_publish(is_public: bool):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.publish(is_public)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
crewai()
|
||||
|
||||
40
src/crewai/cli/command.py
Normal file
40
src/crewai/cli/command.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Dict, Any
|
||||
from rich.console import Console
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.utils import get_auth_token
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class BaseCommand:
|
||||
def __init__(self):
|
||||
self._telemetry = Telemetry()
|
||||
self._telemetry.set_tracer()
|
||||
|
||||
|
||||
class PlusAPIMixin:
|
||||
def __init__(self, telemetry):
|
||||
try:
|
||||
telemetry.set_tracer()
|
||||
self.plus_api_client = PlusAPI(api_key=get_auth_token())
|
||||
except Exception:
|
||||
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
|
||||
console.print(
|
||||
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||
style="bold red",
|
||||
)
|
||||
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
|
||||
raise SystemExit
|
||||
|
||||
def _handle_plus_api_error(self, json_response: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Handle and display error messages from API responses.
|
||||
|
||||
Args:
|
||||
json_response (Dict[str, Any]): The JSON response containing error information.
|
||||
"""
|
||||
error = json_response.get("error", "Unknown error")
|
||||
message = json_response.get("message", "No message provided")
|
||||
console.print(f"Error: {error}", style="bold red")
|
||||
console.print(f"Message: {message}", style="bold red")
|
||||
@@ -1,66 +0,0 @@
|
||||
from os import getenv
|
||||
|
||||
import requests
|
||||
|
||||
from crewai.cli.deploy.utils import get_crewai_version
|
||||
|
||||
|
||||
class CrewAPI:
|
||||
"""
|
||||
CrewAPI class to interact with the crewAI+ API.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
|
||||
}
|
||||
self.base_url = getenv(
|
||||
"CREWAI_BASE_URL", "https://app.crewai.com/crewai_plus/api/v1/crews"
|
||||
)
|
||||
|
||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
||||
url = f"{self.base_url}/{endpoint}"
|
||||
return requests.request(method, url, headers=self.headers, **kwargs)
|
||||
|
||||
# Deploy
|
||||
def deploy_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request("POST", f"by-name/{project_name}/deploy")
|
||||
|
||||
def deploy_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("POST", f"{uuid}/deploy")
|
||||
|
||||
# Status
|
||||
def status_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request("GET", f"by-name/{project_name}/status")
|
||||
|
||||
def status_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("GET", f"{uuid}/status")
|
||||
|
||||
# Logs
|
||||
def logs_by_name(
|
||||
self, project_name: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
return self._make_request("GET", f"by-name/{project_name}/logs/{log_type}")
|
||||
|
||||
def logs_by_uuid(
|
||||
self, uuid: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
return self._make_request("GET", f"{uuid}/logs/{log_type}")
|
||||
|
||||
# Delete
|
||||
def delete_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request("DELETE", f"by-name/{project_name}")
|
||||
|
||||
def delete_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("DELETE", f"{uuid}")
|
||||
|
||||
# List
|
||||
def list_crews(self) -> requests.Response:
|
||||
return self._make_request("GET", "")
|
||||
|
||||
# Create
|
||||
def create_crew(self, payload) -> requests.Response:
|
||||
return self._make_request("POST", "", json=payload)
|
||||
@@ -2,11 +2,9 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.telemetry import Telemetry
|
||||
from .api import CrewAPI
|
||||
from .utils import (
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli.utils import (
|
||||
fetch_and_json_env_file,
|
||||
get_auth_token,
|
||||
get_git_remote_url,
|
||||
get_project_name,
|
||||
)
|
||||
@@ -14,7 +12,7 @@ from .utils import (
|
||||
console = Console()
|
||||
|
||||
|
||||
class DeployCommand:
|
||||
class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
A class to handle deployment-related operations for CrewAI projects.
|
||||
"""
|
||||
@@ -23,40 +21,10 @@ class DeployCommand:
|
||||
"""
|
||||
Initialize the DeployCommand with project name and API client.
|
||||
"""
|
||||
try:
|
||||
self._telemetry = Telemetry()
|
||||
self._telemetry.set_tracer()
|
||||
access_token = get_auth_token()
|
||||
except Exception:
|
||||
self._deploy_signup_error_span = self._telemetry.deploy_signup_error_span()
|
||||
console.print(
|
||||
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||
style="bold red",
|
||||
)
|
||||
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
|
||||
raise SystemExit
|
||||
|
||||
self.project_name = get_project_name()
|
||||
if self.project_name is None:
|
||||
console.print(
|
||||
"No project name found. Please ensure your project has a valid pyproject.toml file.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
self.client = CrewAPI(api_key=access_token)
|
||||
|
||||
def _handle_error(self, json_response: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Handle and display error messages from API responses.
|
||||
|
||||
Args:
|
||||
json_response (Dict[str, Any]): The JSON response containing error information.
|
||||
"""
|
||||
error = json_response.get("error", "Unknown error")
|
||||
message = json_response.get("message", "No message provided")
|
||||
console.print(f"Error: {error}", style="bold red")
|
||||
console.print(f"Message: {message}", style="bold red")
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
self.project_name = get_project_name(require=True)
|
||||
|
||||
def _standard_no_param_error_message(self) -> None:
|
||||
"""
|
||||
@@ -104,9 +72,9 @@ class DeployCommand:
|
||||
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
|
||||
console.print("Starting deployment...", style="bold blue")
|
||||
if uuid:
|
||||
response = self.client.deploy_by_uuid(uuid)
|
||||
response = self.plus_api_client.deploy_by_uuid(uuid)
|
||||
elif self.project_name:
|
||||
response = self.client.deploy_by_name(self.project_name)
|
||||
response = self.plus_api_client.deploy_by_name(self.project_name)
|
||||
else:
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
@@ -115,7 +83,7 @@ class DeployCommand:
|
||||
if response.status_code == 200:
|
||||
self._display_deployment_info(json_response)
|
||||
else:
|
||||
self._handle_error(json_response)
|
||||
self._handle_plus_api_error(json_response)
|
||||
|
||||
def create_crew(self, confirm: bool = False) -> None:
|
||||
"""
|
||||
@@ -139,11 +107,11 @@ class DeployCommand:
|
||||
self._confirm_input(env_vars, remote_repo_url, confirm)
|
||||
payload = self._create_payload(env_vars, remote_repo_url)
|
||||
|
||||
response = self.client.create_crew(payload)
|
||||
response = self.plus_api_client.create_crew(payload)
|
||||
if response.status_code == 201:
|
||||
self._display_creation_success(response.json())
|
||||
else:
|
||||
self._handle_error(response.json())
|
||||
self._handle_plus_api_error(response.json())
|
||||
|
||||
def _confirm_input(
|
||||
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
|
||||
@@ -208,7 +176,7 @@ class DeployCommand:
|
||||
"""
|
||||
console.print("Listing all Crews\n", style="bold blue")
|
||||
|
||||
response = self.client.list_crews()
|
||||
response = self.plus_api_client.list_crews()
|
||||
json_response = response.json()
|
||||
if response.status_code == 200:
|
||||
self._display_crews(json_response)
|
||||
@@ -243,9 +211,9 @@ class DeployCommand:
|
||||
"""
|
||||
console.print("Fetching deployment status...", style="bold blue")
|
||||
if uuid:
|
||||
response = self.client.status_by_uuid(uuid)
|
||||
response = self.plus_api_client.crew_status_by_uuid(uuid)
|
||||
elif self.project_name:
|
||||
response = self.client.status_by_name(self.project_name)
|
||||
response = self.plus_api_client.crew_status_by_name(self.project_name)
|
||||
else:
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
@@ -254,7 +222,7 @@ class DeployCommand:
|
||||
if response.status_code == 200:
|
||||
self._display_crew_status(json_response)
|
||||
else:
|
||||
self._handle_error(json_response)
|
||||
self._handle_plus_api_error(json_response)
|
||||
|
||||
def _display_crew_status(self, status_data: Dict[str, str]) -> None:
|
||||
"""
|
||||
@@ -278,9 +246,9 @@ class DeployCommand:
|
||||
console.print(f"Fetching {log_type} logs...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
response = self.client.logs_by_uuid(uuid, log_type)
|
||||
response = self.plus_api_client.crew_by_uuid(uuid, log_type)
|
||||
elif self.project_name:
|
||||
response = self.client.logs_by_name(self.project_name, log_type)
|
||||
response = self.plus_api_client.crew_by_name(self.project_name, log_type)
|
||||
else:
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
@@ -288,7 +256,7 @@ class DeployCommand:
|
||||
if response.status_code == 200:
|
||||
self._display_logs(response.json())
|
||||
else:
|
||||
self._handle_error(response.json())
|
||||
self._handle_plus_api_error(response.json())
|
||||
|
||||
def remove_crew(self, uuid: Optional[str]) -> None:
|
||||
"""
|
||||
@@ -301,9 +269,9 @@ class DeployCommand:
|
||||
console.print("Removing deployment...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
response = self.client.delete_by_uuid(uuid)
|
||||
response = self.plus_api_client.delete_crew_by_uuid(uuid)
|
||||
elif self.project_name:
|
||||
response = self.client.delete_by_name(self.project_name)
|
||||
response = self.plus_api_client.delete_crew_by_name(self.project_name)
|
||||
else:
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
import sys
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from ..authentication.utils import TokenManager
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
|
||||
|
||||
# Drop the simple_toml_parser when we move to python3.11
|
||||
def simple_toml_parser(content):
|
||||
result = {}
|
||||
current_section = result
|
||||
for line in content.split('\n'):
|
||||
line = line.strip()
|
||||
if line.startswith('[') and line.endswith(']'):
|
||||
# New section
|
||||
section = line[1:-1].split('.')
|
||||
current_section = result
|
||||
for key in section:
|
||||
current_section = current_section.setdefault(key, {})
|
||||
elif '=' in line:
|
||||
key, value = line.split('=', 1)
|
||||
key = key.strip()
|
||||
value = value.strip().strip('"')
|
||||
current_section[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def parse_toml(content):
|
||||
if sys.version_info >= (3, 11):
|
||||
return tomllib.loads(content)
|
||||
else:
|
||||
return simple_toml_parser(content)
|
||||
|
||||
|
||||
def get_git_remote_url() -> str | None:
|
||||
"""Get the Git repository's remote URL."""
|
||||
try:
|
||||
# Run the git remote -v command
|
||||
result = subprocess.run(
|
||||
["git", "remote", "-v"], capture_output=True, text=True, check=True
|
||||
)
|
||||
|
||||
# Get the output
|
||||
output = result.stdout
|
||||
|
||||
# Parse the output to find the origin URL
|
||||
matches = re.findall(r"origin\s+(.*?)\s+\(fetch\)", output)
|
||||
|
||||
if matches:
|
||||
return matches[0] # Return the first match (origin URL)
|
||||
else:
|
||||
console.print("No origin remote found.", style="bold red")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"Error running trying to fetch the Git Repository: {e}", style="bold red")
|
||||
except FileNotFoundError:
|
||||
console.print("Git command not found. Make sure Git is installed and in your PATH.", style="bold red")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_project_name(pyproject_path: str = "pyproject.toml") -> str | None:
|
||||
"""Get the project name from the pyproject.toml file."""
|
||||
try:
|
||||
# Read the pyproject.toml file
|
||||
with open(pyproject_path, "r") as f:
|
||||
pyproject_content = parse_toml(f.read())
|
||||
|
||||
# Extract the project name
|
||||
project_name = pyproject_content["tool"]["poetry"]["name"]
|
||||
|
||||
if "crewai" not in pyproject_content["tool"]["poetry"]["dependencies"]:
|
||||
raise Exception("crewai is not in the dependencies.")
|
||||
|
||||
return project_name
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {pyproject_path} not found.")
|
||||
except KeyError:
|
||||
print(f"Error: {pyproject_path} is not a valid pyproject.toml file.")
|
||||
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
|
||||
print(
|
||||
f"Error: {pyproject_path} is not a valid TOML file."
|
||||
if sys.version_info >= (3, 11)
|
||||
else f"Error reading the pyproject.toml file: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error reading the pyproject.toml file: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_crewai_version(poetry_lock_path: str = "poetry.lock") -> str:
|
||||
"""Get the version number of crewai from the poetry.lock file."""
|
||||
try:
|
||||
with open(poetry_lock_path, "r") as f:
|
||||
lock_content = f.read()
|
||||
|
||||
match = re.search(
|
||||
r'\[\[package\]\]\s*name\s*=\s*"crewai"\s*version\s*=\s*"([^"]+)"',
|
||||
lock_content,
|
||||
re.DOTALL,
|
||||
)
|
||||
if match:
|
||||
return match.group(1)
|
||||
else:
|
||||
print("crewai package not found in poetry.lock")
|
||||
return "no-version-found"
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {poetry_lock_path} not found.")
|
||||
except Exception as e:
|
||||
print(f"Error reading the poetry.lock file: {e}")
|
||||
|
||||
return "no-version-found"
|
||||
|
||||
|
||||
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
||||
try:
|
||||
# Read the .env file
|
||||
with open(env_file_path, "r") as f:
|
||||
env_content = f.read()
|
||||
|
||||
# Parse the .env file content to a dictionary
|
||||
env_dict = {}
|
||||
for line in env_content.splitlines():
|
||||
if line.strip() and not line.strip().startswith("#"):
|
||||
key, value = line.split("=", 1)
|
||||
env_dict[key.strip()] = value.strip()
|
||||
|
||||
return env_dict
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {env_file_path} not found.")
|
||||
except Exception as e:
|
||||
print(f"Error reading the .env file: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def get_auth_token() -> str:
|
||||
"""Get the authentication token."""
|
||||
access_token = TokenManager().get_token()
|
||||
if not access_token:
|
||||
raise Exception()
|
||||
return access_token
|
||||
92
src/crewai/cli/plus_api.py
Normal file
92
src/crewai/cli/plus_api.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Optional
|
||||
import requests
|
||||
from os import getenv
|
||||
from crewai.cli.utils import get_crewai_version
|
||||
from urllib.parse import urljoin
|
||||
|
||||
|
||||
class PlusAPI:
|
||||
"""
|
||||
This class exposes methods for working with the CrewAI+ API.
|
||||
"""
|
||||
|
||||
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
|
||||
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
|
||||
"X-Crewai-Version": get_crewai_version(),
|
||||
}
|
||||
self.base_url = getenv("CREWAI_BASE_URL", "https://app.crewai.com")
|
||||
|
||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
return requests.request(method, url, headers=self.headers, **kwargs)
|
||||
|
||||
def get_tool(self, handle: str):
|
||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||
|
||||
def publish_tool(
|
||||
self,
|
||||
handle: str,
|
||||
is_public: bool,
|
||||
version: str,
|
||||
description: Optional[str],
|
||||
encoded_file: str,
|
||||
):
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": is_public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
}
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
|
||||
|
||||
def deploy_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
|
||||
)
|
||||
|
||||
def deploy_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
|
||||
|
||||
def crew_status_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
|
||||
)
|
||||
|
||||
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
|
||||
|
||||
def crew_by_name(
|
||||
self, project_name: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
|
||||
)
|
||||
|
||||
def crew_by_uuid(
|
||||
self, uuid: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
|
||||
)
|
||||
|
||||
def delete_crew_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
|
||||
)
|
||||
|
||||
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
|
||||
|
||||
def list_crews(self) -> requests.Response:
|
||||
return self._make_request("GET", self.CREWS_RESOURCE)
|
||||
|
||||
def create_crew(self, payload) -> requests.Response:
|
||||
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
||||
@@ -6,7 +6,7 @@ authors = ["Your Name <you@example.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
crewai = { extras = ["tools"], version = ">=0.63.6,<1.0.0" }
|
||||
crewai = { extras = ["tools"], version = ">=0.64.0,<1.0.0" }
|
||||
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
||||
@@ -6,7 +6,7 @@ authors = ["Your Name <you@example.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
crewai = { extras = ["tools"], version = ">=0.63.6,<1.0.0" }
|
||||
crewai = { extras = ["tools"], version = ">=0.64.0,<1.0.0" }
|
||||
asyncio = "*"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
||||
@@ -6,7 +6,7 @@ authors = ["Your Name <you@example.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<=3.13"
|
||||
crewai = { extras = ["tools"], version = ">=0.63.6,<1.0.0" }
|
||||
crewai = { extras = ["tools"], version = ">=0.64.0,<1.0.0" }
|
||||
|
||||
|
||||
[tool.poetry.scripts]
|
||||
|
||||
0
src/crewai/cli/tools/__init__.py
Normal file
0
src/crewai/cli/tools/__init__.py
Normal file
168
src/crewai/cli/tools/main.py
Normal file
168
src/crewai/cli/tools/main.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import base64
|
||||
import click
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli.utils import (
|
||||
get_project_name,
|
||||
get_project_description,
|
||||
get_project_version,
|
||||
)
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
A class to handle tool repository related operations for CrewAI projects.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
|
||||
def publish(self, is_public: bool):
|
||||
project_name = get_project_name(require=True)
|
||||
assert isinstance(project_name, str)
|
||||
|
||||
project_version = get_project_version(require=True)
|
||||
assert isinstance(project_version, str)
|
||||
|
||||
project_description = get_project_description(require=False)
|
||||
encoded_tarball = None
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_build_dir:
|
||||
subprocess.run(
|
||||
["poetry", "build", "-f", "sdist", "--output", temp_build_dir],
|
||||
check=True,
|
||||
capture_output=False,
|
||||
)
|
||||
|
||||
tarball_filename = next(
|
||||
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
|
||||
)
|
||||
if not tarball_filename:
|
||||
console.print(
|
||||
"Project build failed. Please ensure that the command `poetry build -f sdist` completes successfully.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
tarball_path = os.path.join(temp_build_dir, tarball_filename)
|
||||
with open(tarball_path, "rb") as file:
|
||||
tarball_contents = file.read()
|
||||
|
||||
encoded_tarball = base64.b64encode(tarball_contents).decode("utf-8")
|
||||
|
||||
publish_response = self.plus_api_client.publish_tool(
|
||||
handle=project_name,
|
||||
is_public=is_public,
|
||||
version=project_version,
|
||||
description=project_description,
|
||||
encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
|
||||
)
|
||||
if publish_response.status_code == 422:
|
||||
console.print(
|
||||
"[bold red]Failed to publish tool. Please fix the following errors:[/bold red]"
|
||||
)
|
||||
for field, messages in publish_response.json().items():
|
||||
for message in messages:
|
||||
console.print(
|
||||
f"* [bold red]{field.capitalize()}[/bold red] {message}"
|
||||
)
|
||||
|
||||
raise SystemExit
|
||||
elif publish_response.status_code != 200:
|
||||
self._handle_plus_api_error(publish_response.json())
|
||||
console.print(
|
||||
"Failed to publish tool. Please try again later.", style="bold red"
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
published_handle = publish_response.json()["handle"]
|
||||
console.print(
|
||||
f"Succesfully published {published_handle} ({project_version}).\nInstall it in other projects with crewai tool install {published_handle}",
|
||||
style="bold green",
|
||||
)
|
||||
|
||||
def install(self, handle: str):
|
||||
get_response = self.plus_api_client.get_tool(handle)
|
||||
|
||||
if get_response.status_code == 404:
|
||||
console.print(
|
||||
"No tool found with this name. Please ensure the tool was published and you have access to it.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
elif get_response.status_code != 200:
|
||||
console.print(
|
||||
"Failed to get tool details. Please try again later.", style="bold red"
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
self._add_repository_to_poetry(get_response.json())
|
||||
self._add_package(get_response.json())
|
||||
|
||||
console.print(f"Succesfully installed {handle}", style="bold green")
|
||||
|
||||
def _add_repository_to_poetry(self, tool_details):
|
||||
repository_handle = f"crewai-{tool_details['repository']['handle']}"
|
||||
repository_url = tool_details["repository"]["url"]
|
||||
repository_credentials = tool_details["repository"]["credentials"]
|
||||
|
||||
add_repository_command = [
|
||||
"poetry",
|
||||
"source",
|
||||
"add",
|
||||
"--priority=explicit",
|
||||
repository_handle,
|
||||
repository_url,
|
||||
]
|
||||
add_repository_result = subprocess.run(
|
||||
add_repository_command, text=True, check=True
|
||||
)
|
||||
|
||||
if add_repository_result.stderr:
|
||||
click.echo(add_repository_result.stderr, err=True)
|
||||
raise SystemExit
|
||||
|
||||
add_repository_credentials_command = [
|
||||
"poetry",
|
||||
"config",
|
||||
f"http-basic.{repository_handle}",
|
||||
repository_credentials,
|
||||
'""',
|
||||
]
|
||||
add_repository_credentials_result = subprocess.run(
|
||||
add_repository_credentials_command,
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
if add_repository_credentials_result.stderr:
|
||||
click.echo(add_repository_credentials_result.stderr, err=True)
|
||||
raise SystemExit
|
||||
|
||||
def _add_package(self, tool_details):
|
||||
tool_handle = tool_details["handle"]
|
||||
repository_handle = tool_details["repository"]["handle"]
|
||||
pypi_index_handle = f"crewai-{repository_handle}"
|
||||
|
||||
add_package_command = [
|
||||
"poetry",
|
||||
"add",
|
||||
"--source",
|
||||
pypi_index_handle,
|
||||
tool_handle,
|
||||
]
|
||||
add_package_result = subprocess.run(
|
||||
add_package_command, capture_output=False, text=True, check=True
|
||||
)
|
||||
|
||||
if add_package_result.stderr:
|
||||
click.echo(add_package_result.stderr, err=True)
|
||||
raise SystemExit
|
||||
@@ -1,4 +1,17 @@
|
||||
import click
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from crewai.cli.authentication.utils import TokenManager
|
||||
from functools import reduce
|
||||
from rich.console import Console
|
||||
from typing import Any, Dict, List
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def copy_template(src, dst, name, class_name, folder_name):
|
||||
@@ -16,3 +29,191 @@ def copy_template(src, dst, name, class_name, folder_name):
|
||||
file.write(content)
|
||||
|
||||
click.secho(f" - Created {dst}", fg="green")
|
||||
|
||||
|
||||
# Drop the simple_toml_parser when we move to python3.11
|
||||
def simple_toml_parser(content):
|
||||
result = {}
|
||||
current_section = result
|
||||
for line in content.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("[") and line.endswith("]"):
|
||||
# New section
|
||||
section = line[1:-1].split(".")
|
||||
current_section = result
|
||||
for key in section:
|
||||
current_section = current_section.setdefault(key, {})
|
||||
elif "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip().strip('"')
|
||||
current_section[key] = value
|
||||
return result
|
||||
|
||||
|
||||
def parse_toml(content):
|
||||
if sys.version_info >= (3, 11):
|
||||
return tomllib.loads(content)
|
||||
else:
|
||||
return simple_toml_parser(content)
|
||||
|
||||
|
||||
def get_git_remote_url() -> str | None:
|
||||
"""Get the Git repository's remote URL."""
|
||||
try:
|
||||
# Run the git remote -v command
|
||||
result = subprocess.run(
|
||||
["git", "remote", "-v"], capture_output=True, text=True, check=True
|
||||
)
|
||||
|
||||
# Get the output
|
||||
output = result.stdout
|
||||
|
||||
# Parse the output to find the origin URL
|
||||
matches = re.findall(r"origin\s+(.*?)\s+\(fetch\)", output)
|
||||
|
||||
if matches:
|
||||
return matches[0] # Return the first match (origin URL)
|
||||
else:
|
||||
console.print("No origin remote found.", style="bold red")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(
|
||||
f"Error running trying to fetch the Git Repository: {e}", style="bold red"
|
||||
)
|
||||
except FileNotFoundError:
|
||||
console.print(
|
||||
"Git command not found. Make sure Git is installed and in your PATH.",
|
||||
style="bold red",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_project_name(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
"""Get the project name from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["tool", "poetry", "name"], require=require
|
||||
)
|
||||
|
||||
|
||||
def get_project_version(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
"""Get the project version from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["tool", "poetry", "version"], require=require
|
||||
)
|
||||
|
||||
|
||||
def get_project_description(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
"""Get the project description from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["tool", "poetry", "description"], require=require
|
||||
)
|
||||
|
||||
|
||||
def _get_project_attribute(
|
||||
pyproject_path: str, keys: List[str], require: bool
|
||||
) -> Any | None:
|
||||
"""Get an attribute from the pyproject.toml file."""
|
||||
attribute = None
|
||||
|
||||
try:
|
||||
with open(pyproject_path, "r") as f:
|
||||
pyproject_content = parse_toml(f.read())
|
||||
|
||||
dependencies = (
|
||||
_get_nested_value(pyproject_content, ["tool", "poetry", "dependencies"])
|
||||
or {}
|
||||
)
|
||||
if "crewai" not in dependencies:
|
||||
raise Exception("crewai is not in the dependencies.")
|
||||
|
||||
attribute = _get_nested_value(pyproject_content, keys)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {pyproject_path} not found.")
|
||||
except KeyError:
|
||||
print(f"Error: {pyproject_path} is not a valid pyproject.toml file.")
|
||||
except tomllib.TOMLDecodeError if sys.version_info >= (3, 11) else Exception as e: # type: ignore
|
||||
print(
|
||||
f"Error: {pyproject_path} is not a valid TOML file."
|
||||
if sys.version_info >= (3, 11)
|
||||
else f"Error reading the pyproject.toml file: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error reading the pyproject.toml file: {e}")
|
||||
|
||||
if require and not attribute:
|
||||
console.print(
|
||||
f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
return attribute
|
||||
|
||||
|
||||
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
|
||||
return reduce(dict.__getitem__, keys, data)
|
||||
|
||||
|
||||
def get_crewai_version(poetry_lock_path: str = "poetry.lock") -> str:
|
||||
"""Get the version number of crewai from the poetry.lock file."""
|
||||
try:
|
||||
with open(poetry_lock_path, "r") as f:
|
||||
lock_content = f.read()
|
||||
|
||||
match = re.search(
|
||||
r'\[\[package\]\]\s*name\s*=\s*"crewai"\s*version\s*=\s*"([^"]+)"',
|
||||
lock_content,
|
||||
re.DOTALL,
|
||||
)
|
||||
if match:
|
||||
return match.group(1)
|
||||
else:
|
||||
print("crewai package not found in poetry.lock")
|
||||
return "no-version-found"
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {poetry_lock_path} not found.")
|
||||
except Exception as e:
|
||||
print(f"Error reading the poetry.lock file: {e}")
|
||||
|
||||
return "no-version-found"
|
||||
|
||||
|
||||
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
||||
try:
|
||||
# Read the .env file
|
||||
with open(env_file_path, "r") as f:
|
||||
env_content = f.read()
|
||||
|
||||
# Parse the .env file content to a dictionary
|
||||
env_dict = {}
|
||||
for line in env_content.splitlines():
|
||||
if line.strip() and not line.strip().startswith("#"):
|
||||
key, value = line.split("=", 1)
|
||||
env_dict[key.strip()] = value.strip()
|
||||
|
||||
return env_dict
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {env_file_path} not found.")
|
||||
except Exception as e:
|
||||
print(f"Error reading the .env file: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def get_auth_token() -> str:
|
||||
"""Get the authentication token."""
|
||||
access_token = TokenManager().get_token()
|
||||
if not access_token:
|
||||
raise Exception()
|
||||
return access_token
|
||||
|
||||
@@ -1,8 +1,75 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import logging
|
||||
import warnings
|
||||
import litellm
|
||||
from litellm import get_supported_openai_params
|
||||
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
|
||||
import sys
|
||||
import io
|
||||
|
||||
|
||||
class FilteredStream(io.StringIO):
|
||||
def write(self, s):
|
||||
if (
|
||||
"Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
|
||||
in s
|
||||
or "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True`"
|
||||
in s
|
||||
):
|
||||
return
|
||||
super().write(s)
|
||||
|
||||
|
||||
LLM_CONTEXT_WINDOW_SIZES = {
|
||||
# openai
|
||||
"gpt-4": 8192,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 128000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"o1-preview": 128000,
|
||||
"o1-mini": 128000,
|
||||
# deepseek
|
||||
"deepseek-chat": 128000,
|
||||
# groq
|
||||
"gemma2-9b-it": 8192,
|
||||
"gemma-7b-it": 8192,
|
||||
"llama3-groq-70b-8192-tool-use-preview": 8192,
|
||||
"llama3-groq-8b-8192-tool-use-preview": 8192,
|
||||
"llama-3.1-70b-versatile": 131072,
|
||||
"llama-3.1-8b-instant": 131072,
|
||||
"llama-3.2-1b-preview": 8192,
|
||||
"llama-3.2-3b-preview": 8192,
|
||||
"llama-3.2-11b-text-preview": 8192,
|
||||
"llama-3.2-90b-text-preview": 8192,
|
||||
"llama3-70b-8192": 8192,
|
||||
"llama3-8b-8192": 8192,
|
||||
"mixtral-8x7b-32768": 32768,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_warnings():
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# Redirect stdout and stderr
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
sys.stdout = FilteredStream()
|
||||
sys.stderr = FilteredStream()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Restore stdout and stderr
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
|
||||
class LLM:
|
||||
def __init__(
|
||||
@@ -50,43 +117,50 @@ class LLM:
|
||||
self.kwargs = kwargs
|
||||
|
||||
litellm.drop_params = True
|
||||
litellm.set_verbose = False
|
||||
litellm.callbacks = callbacks
|
||||
|
||||
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
|
||||
if callbacks and len(callbacks) > 0:
|
||||
litellm.callbacks = callbacks
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
litellm.callbacks = callbacks
|
||||
|
||||
try:
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
**self.kwargs,
|
||||
}
|
||||
try:
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": False,
|
||||
**self.kwargs,
|
||||
}
|
||||
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
response = litellm.completion(**params)
|
||||
return response["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||
raise # Re-raise the exception after logging
|
||||
response = litellm.completion(**params)
|
||||
return response["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
if not LLMContextLengthExceededException(
|
||||
str(e)
|
||||
)._is_context_limit_error(str(e)):
|
||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||
|
||||
raise # Re-raise the exception after logging
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
try:
|
||||
@@ -103,3 +177,7 @@ class LLM:
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get supported params: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
# Only using 75% of the context window size to avoid cutting the message in the middle
|
||||
return int(LLM_CONTEXT_WINDOW_SIZES.get(self.model, 8192) * 0.75)
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Any, Callable, Dict, Type, TypeVar
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from crewai.crew import Crew
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def CrewBase(cls):
|
||||
T = TypeVar("T", bound=Type[Any])
|
||||
|
||||
|
||||
def CrewBase(cls: T) -> T:
|
||||
class WrappedClass(cls):
|
||||
is_crew_class: bool = True # type: ignore
|
||||
|
||||
@@ -32,6 +37,19 @@ def CrewBase(cls):
|
||||
self.map_all_agent_variables()
|
||||
self.map_all_task_variables()
|
||||
|
||||
def crew(self) -> "Crew":
|
||||
agents = [
|
||||
getattr(self, name)()
|
||||
for name, func in self._get_all_functions().items()
|
||||
if hasattr(func, "is_agent")
|
||||
]
|
||||
tasks = [
|
||||
getattr(self, name)()
|
||||
for name, func in reversed(self._get_all_functions().items())
|
||||
if hasattr(func, "is_task")
|
||||
]
|
||||
return Crew(agents=agents, tasks=tasks)
|
||||
|
||||
@staticmethod
|
||||
def load_yaml(config_path: Path):
|
||||
try:
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"task_with_context": "{task}\n\nThis is the context you're working with:\n{context}",
|
||||
"expected_output": "\nThis is the expect criteria for your final answer: {expected_output}\nyou MUST return the actual complete content as the final answer, not a summary.",
|
||||
"human_feedback": "You got human feedback on your work, re-evaluate it and give a new Final Answer when ready.\n {human_feedback}",
|
||||
"getting_input": "This is the agent's final answer: {final_answer}\nPlease provide feedback: ",
|
||||
"getting_input": "This is the agent's final answer: {final_answer}\n\n",
|
||||
"summarizer_system_message": "You are a helpful assistant that summarizes text.",
|
||||
"sumamrize_instruction": "Summarize the following text, make sure to include all the important information: {group}",
|
||||
"summary": "This is a summary of our conversation so far:\n{merged_summary}"
|
||||
|
||||
@@ -49,7 +49,7 @@ class TaskEvaluation(BaseModel):
|
||||
|
||||
class TrainingTaskEvaluation(BaseModel):
|
||||
suggestions: List[str] = Field(
|
||||
description="Based on the Human Feedbacks and the comparison between Initial Outputs and Improved outputs provide action items based on human_feedback for future tasks."
|
||||
description="List of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specific points from the human feedback are incorporated into these instructions."
|
||||
)
|
||||
quality: float = Field(
|
||||
description="A score from 0 to 10 evaluating on completion, quality, and overall performance from the improved output to the initial output based on the human feedback."
|
||||
@@ -116,7 +116,7 @@ class TaskEvaluator:
|
||||
"Assess the quality of the training data based on the llm output, human feedback , and llm output improved result.\n\n"
|
||||
f"{final_aggregated_data}"
|
||||
"Please provide:\n"
|
||||
"- Based on the Human Feedbacks and the comparison between Initial Outputs and Improved outputs provide action items based on human_feedback for future tasks\n"
|
||||
"- Provide a list of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific action items for future tasks. Ensure all key and specificpoints from the human feedback are incorporated into these instructions.\n"
|
||||
"- A score from 0 to 10 evaluating on completion, quality, and overall performance from the improved output to the initial output based on the human feedback\n"
|
||||
)
|
||||
instructions = "I'm gonna convert this raw text into valid JSON."
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
class LLMContextLengthExceededException(Exception):
|
||||
CONTEXT_LIMIT_ERRORS = [
|
||||
"expected a string with maximum length",
|
||||
"maximum context length",
|
||||
"context length exceeded",
|
||||
"context_length_exceeded",
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
import unittest
|
||||
from os import environ
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.cli.deploy.api import CrewAPI
|
||||
|
||||
|
||||
class TestCrewAPI(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_api_key"
|
||||
self.api = CrewAPI(self.api_key)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.api.api_key, self.api_key)
|
||||
self.assertEqual(
|
||||
self.api.headers,
|
||||
{
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "CrewAI-CLI/no-version-found"
|
||||
},
|
||||
)
|
||||
|
||||
@patch("crewai.cli.deploy.api.requests.request")
|
||||
def test_make_request(self, mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
response = self.api._make_request("GET", "test_endpoint")
|
||||
|
||||
mock_request.assert_called_once_with(
|
||||
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_deploy_by_name(self, mock_make_request):
|
||||
self.api.deploy_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with("POST", "by-name/test_project/deploy")
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_deploy_by_uuid(self, mock_make_request):
|
||||
self.api.deploy_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with("POST", "test_uuid/deploy")
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_status_by_name(self, mock_make_request):
|
||||
self.api.status_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with("GET", "by-name/test_project/status")
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_status_by_uuid(self, mock_make_request):
|
||||
self.api.status_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with("GET", "test_uuid/status")
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_logs_by_name(self, mock_make_request):
|
||||
self.api.logs_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "by-name/test_project/logs/deployment"
|
||||
)
|
||||
|
||||
self.api.logs_by_name("test_project", "custom_log")
|
||||
mock_make_request.assert_called_with(
|
||||
"GET", "by-name/test_project/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_logs_by_uuid(self, mock_make_request):
|
||||
self.api.logs_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with("GET", "test_uuid/logs/deployment")
|
||||
|
||||
self.api.logs_by_uuid("test_uuid", "custom_log")
|
||||
mock_make_request.assert_called_with("GET", "test_uuid/logs/custom_log")
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_delete_by_name(self, mock_make_request):
|
||||
self.api.delete_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with("DELETE", "by-name/test_project")
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_delete_by_uuid(self, mock_make_request):
|
||||
self.api.delete_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with("DELETE", "test_uuid")
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_list_crews(self, mock_make_request):
|
||||
self.api.list_crews()
|
||||
mock_make_request.assert_called_once_with("GET", "")
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
def test_create_crew(self, mock_make_request):
|
||||
payload = {"name": "test_crew"}
|
||||
self.api.create_crew(payload)
|
||||
mock_make_request.assert_called_once_with("POST", "", json=payload)
|
||||
|
||||
@patch.dict(environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
|
||||
def test_custom_base_url(self):
|
||||
custom_api = CrewAPI("test_key")
|
||||
self.assertEqual(
|
||||
custom_api.base_url,
|
||||
"https://custom-url.com/api",
|
||||
)
|
||||
@@ -4,37 +4,38 @@ from unittest.mock import MagicMock, patch
|
||||
import sys
|
||||
|
||||
from crewai.cli.deploy.main import DeployCommand
|
||||
from crewai.cli.deploy.utils import parse_toml
|
||||
from crewai.cli.utils import parse_toml
|
||||
|
||||
|
||||
class TestDeployCommand(unittest.TestCase):
|
||||
@patch("crewai.cli.deploy.main.get_auth_token")
|
||||
@patch("crewai.cli.command.get_auth_token")
|
||||
@patch("crewai.cli.deploy.main.get_project_name")
|
||||
@patch("crewai.cli.deploy.main.CrewAPI")
|
||||
def setUp(self, mock_crew_api, mock_get_project_name, mock_get_auth_token):
|
||||
@patch("crewai.cli.command.PlusAPI")
|
||||
def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token):
|
||||
self.mock_get_auth_token = mock_get_auth_token
|
||||
self.mock_get_project_name = mock_get_project_name
|
||||
self.mock_crew_api = mock_crew_api
|
||||
self.mock_plus_api = mock_plus_api
|
||||
|
||||
self.mock_get_auth_token.return_value = "test_token"
|
||||
self.mock_get_project_name.return_value = "test_project"
|
||||
|
||||
self.deploy_command = DeployCommand()
|
||||
self.mock_client = self.deploy_command.client
|
||||
self.mock_client = self.deploy_command.plus_api_client
|
||||
|
||||
def test_init_success(self):
|
||||
self.assertEqual(self.deploy_command.project_name, "test_project")
|
||||
self.mock_crew_api.assert_called_once_with(api_key="test_token")
|
||||
self.mock_plus_api.assert_called_once_with(api_key="test_token")
|
||||
|
||||
@patch("crewai.cli.deploy.main.get_auth_token")
|
||||
@patch("crewai.cli.command.get_auth_token")
|
||||
def test_init_failure(self, mock_get_auth_token):
|
||||
mock_get_auth_token.side_effect = Exception("Auth failed")
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
DeployCommand()
|
||||
|
||||
def test_handle_error(self):
|
||||
def test_handle_plus_api_error(self):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command._handle_error(
|
||||
self.deploy_command._handle_plus_api_error(
|
||||
{"error": "Test error", "message": "Test message"}
|
||||
)
|
||||
self.assertIn("Error: Test error", fake_out.getvalue())
|
||||
@@ -121,7 +122,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"name": "TestCrew", "status": "active"}
|
||||
self.mock_client.status_by_name.return_value = mock_response
|
||||
self.mock_client.crew_status_by_name.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.get_crew_status()
|
||||
@@ -135,7 +136,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
{"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"},
|
||||
{"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"},
|
||||
]
|
||||
self.mock_client.logs_by_name.return_value = mock_response
|
||||
self.mock_client.crew_by_name.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.get_crew_logs(None)
|
||||
@@ -145,7 +146,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
def test_remove_crew(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 204
|
||||
self.mock_client.delete_by_name.return_value = mock_response
|
||||
self.mock_client.delete_crew_by_name.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.remove_crew(None)
|
||||
@@ -165,9 +166,12 @@ class TestDeployCommand(unittest.TestCase):
|
||||
crewai = { extras = ["tools"], version = ">=0.51.0,<1.0.0" }
|
||||
"""
|
||||
parsed = parse_toml(toml_content)
|
||||
self.assertEqual(parsed['tool']['poetry']['name'], 'test_project')
|
||||
self.assertEqual(parsed["tool"]["poetry"]["name"], "test_project")
|
||||
|
||||
@patch('builtins.open', new_callable=unittest.mock.mock_open, read_data="""
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data="""
|
||||
[tool.poetry]
|
||||
name = "test_project"
|
||||
version = "0.1.0"
|
||||
@@ -175,14 +179,19 @@ class TestDeployCommand(unittest.TestCase):
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
crewai = { extras = ["tools"], version = ">=0.51.0,<1.0.0" }
|
||||
""")
|
||||
""",
|
||||
)
|
||||
def test_get_project_name_python_310(self, mock_open):
|
||||
from crewai.cli.deploy.utils import get_project_name
|
||||
from crewai.cli.utils import get_project_name
|
||||
|
||||
project_name = get_project_name()
|
||||
self.assertEqual(project_name, 'test_project')
|
||||
self.assertEqual(project_name, "test_project")
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Requires Python 3.11+")
|
||||
@patch('builtins.open', new_callable=unittest.mock.mock_open, read_data="""
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data="""
|
||||
[tool.poetry]
|
||||
name = "test_project"
|
||||
version = "0.1.0"
|
||||
@@ -190,13 +199,18 @@ class TestDeployCommand(unittest.TestCase):
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
crewai = { extras = ["tools"], version = ">=0.51.0,<1.0.0" }
|
||||
""")
|
||||
""",
|
||||
)
|
||||
def test_get_project_name_python_311_plus(self, mock_open):
|
||||
from crewai.cli.deploy.utils import get_project_name
|
||||
project_name = get_project_name()
|
||||
self.assertEqual(project_name, 'test_project')
|
||||
from crewai.cli.utils import get_project_name
|
||||
|
||||
@patch('builtins.open', new_callable=unittest.mock.mock_open, read_data="""
|
||||
project_name = get_project_name()
|
||||
self.assertEqual(project_name, "test_project")
|
||||
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data="""
|
||||
[[package]]
|
||||
name = "crewai"
|
||||
version = "0.51.1"
|
||||
@@ -204,16 +218,19 @@ class TestDeployCommand(unittest.TestCase):
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.10,<4.0"
|
||||
""")
|
||||
""",
|
||||
)
|
||||
def test_get_crewai_version(self, mock_open):
|
||||
from crewai.cli.deploy.utils import get_crewai_version
|
||||
version = get_crewai_version()
|
||||
self.assertEqual(version, '0.51.1')
|
||||
from crewai.cli.utils import get_crewai_version
|
||||
|
||||
@patch('builtins.open', side_effect=FileNotFoundError)
|
||||
version = get_crewai_version()
|
||||
self.assertEqual(version, "0.51.1")
|
||||
|
||||
@patch("builtins.open", side_effect=FileNotFoundError)
|
||||
def test_get_crewai_version_file_not_found(self, mock_open):
|
||||
from crewai.cli.deploy.utils import get_crewai_version
|
||||
with patch('sys.stdout', new=StringIO()) as fake_out:
|
||||
from crewai.cli.utils import get_crewai_version
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
version = get_crewai_version()
|
||||
self.assertEqual(version, 'no-version-found')
|
||||
self.assertEqual(version, "no-version-found")
|
||||
self.assertIn("Error: poetry.lock not found.", fake_out.getvalue())
|
||||
|
||||
185
tests/cli/test_plus_api.py
Normal file
185
tests/cli/test_plus_api.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
|
||||
|
||||
class TestPlusAPI(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_api_key"
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.api.api_key, self.api_key)
|
||||
self.assertEqual(
|
||||
self.api.headers,
|
||||
{
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "CrewAI-CLI/no-version-found",
|
||||
"X-Crewai-Version": "no-version-found",
|
||||
},
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_get_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.get_tool("test_tool_handle")
|
||||
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool_without_description(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = False
|
||||
version = "2.0.0"
|
||||
description = None
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.requests.request")
|
||||
def test_make_request(self, mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
response = self.api._make_request("GET", "test_endpoint")
|
||||
|
||||
mock_request.assert_called_once_with(
|
||||
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_deploy_by_name(self, mock_make_request):
|
||||
self.api.deploy_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_deploy_by_uuid(self, mock_make_request):
|
||||
self.api.deploy_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_status_by_name(self, mock_make_request):
|
||||
self.api.crew_status_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_status_by_uuid(self, mock_make_request):
|
||||
self.api.crew_status_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_by_name(self, mock_make_request):
|
||||
self.api.crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
|
||||
)
|
||||
|
||||
self.api.crew_by_name("test_project", "custom_log")
|
||||
mock_make_request.assert_called_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_by_uuid(self, mock_make_request):
|
||||
self.api.crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
|
||||
)
|
||||
|
||||
self.api.crew_by_uuid("test_uuid", "custom_log")
|
||||
mock_make_request.assert_called_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_delete_crew_by_name(self, mock_make_request):
|
||||
self.api.delete_crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_delete_crew_by_uuid(self, mock_make_request):
|
||||
self.api.delete_crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_list_crews(self, mock_make_request):
|
||||
self.api.list_crews()
|
||||
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_create_crew(self, mock_make_request):
|
||||
payload = {"name": "test_crew"}
|
||||
self.api.create_crew(payload)
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews", json=payload
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
|
||||
def test_custom_base_url(self):
|
||||
custom_api = PlusAPI("test_key")
|
||||
self.assertEqual(
|
||||
custom_api.base_url,
|
||||
"https://custom-url.com/api",
|
||||
)
|
||||
0
tests/cli/tools/__init__.py
Normal file
0
tests/cli/tools/__init__.py
Normal file
229
tests/cli/tools/test_main.py
Normal file
229
tests/cli/tools/test_main.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import unittest
|
||||
import unittest.mock
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
from io import StringIO
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestToolCommand(unittest.TestCase):
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_success(self, mock_get, mock_subprocess_run):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {
|
||||
"handle": "sample-tool",
|
||||
"repository": {
|
||||
"handle": "sample-repo",
|
||||
"url": "https://example.com/repo",
|
||||
"credentials": "my_very_secret",
|
||||
},
|
||||
}
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
tool_command.install("sample-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("sample-tool")
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"poetry",
|
||||
"source",
|
||||
"add",
|
||||
"--priority=explicit",
|
||||
"crewai-sample-repo",
|
||||
"https://example.com/repo",
|
||||
],
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"poetry",
|
||||
"config",
|
||||
"http-basic.crewai-sample-repo",
|
||||
"my_very_secret",
|
||||
'""',
|
||||
],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
["poetry", "add", "--source", "crewai-sample-repo", "sample-tool"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
self.assertIn("Succesfully installed sample-tool", output)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_tool_not_found(self, mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 404
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.install("non-existent-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("non-existent-tool")
|
||||
self.assertIn("No tool found with this name", output)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_api_error(self, mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 500
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.install("error-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("error-tool")
|
||||
self.assertIn("Failed to get tool details", output)
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
|
||||
)
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_success(
|
||||
self,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
mock_get_project_name.assert_called_once_with(require=True)
|
||||
mock_get_project_version.assert_called_once_with(require=True)
|
||||
mock_get_project_description.assert_called_once_with(require=False)
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["poetry", "build", "-f", "sdist", "--output", unittest.mock.ANY],
|
||||
check=True,
|
||||
capture_output=False,
|
||||
)
|
||||
mock_open.assert_called_once_with(unittest.mock.ANY, "rb")
|
||||
mock_publish.assert_called_once_with(
|
||||
handle="sample-tool",
|
||||
is_public=True,
|
||||
version="1.0.0",
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
)
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
|
||||
)
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_failure(
|
||||
self,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 422
|
||||
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
self.assertIn("Failed to publish tool", output)
|
||||
self.assertIn("Name is already taken", output)
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
|
||||
)
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_api_error(
|
||||
self,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 500
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
self.assertIn("Failed to publish tool", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -2335,6 +2335,50 @@ def test_key():
|
||||
assert crew.key == hash
|
||||
|
||||
|
||||
def test_key_with_interpolated_inputs():
|
||||
researcher = Agent(
|
||||
role="{topic} Researcher",
|
||||
goal="Make the best research and analysis on content {topic}",
|
||||
backstory="You're an expert researcher, specialized in technology, software engineering, AI and startups. You work as a freelancer and is now working on doing research and analysis for a new customer.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
writer = Agent(
|
||||
role="{topic} Senior Writer",
|
||||
goal="Write the best content about {topic}",
|
||||
backstory="You're a senior writer, specialized in technology, software engineering, AI and startups. You work as a freelancer and are now working on writing content for a new customer.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
tasks = [
|
||||
Task(
|
||||
description="Give me a list of 5 interesting ideas about {topic} to explore for an article, what makes them unique and interesting.",
|
||||
expected_output="Bullet point list of 5 important events.",
|
||||
agent=researcher,
|
||||
),
|
||||
Task(
|
||||
description="Write a 1 amazing paragraph highlight for each idea of {topic} that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
|
||||
expected_output="A 4 paragraph article about AI.",
|
||||
agent=writer,
|
||||
),
|
||||
]
|
||||
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
process=Process.sequential,
|
||||
tasks=tasks,
|
||||
)
|
||||
hash = hashlib.md5(
|
||||
f"{researcher.key}|{writer.key}|{tasks[0].key}|{tasks[1].key}".encode()
|
||||
).hexdigest()
|
||||
|
||||
assert crew.key == hash
|
||||
|
||||
curr_key = crew.key
|
||||
crew._interpolate_inputs({"topic": "AI"})
|
||||
assert crew.key == curr_key
|
||||
|
||||
|
||||
def test_conditional_task_requirement_breaks_when_singular_conditional_task():
|
||||
def condition_fn(output) -> bool:
|
||||
return output.raw.startswith("Andrew Ng has!!")
|
||||
|
||||
Reference in New Issue
Block a user