mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 09:38:17 +00:00
adding tests. cleaing up pr.
This commit is contained in:
@@ -8,7 +8,6 @@ from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
|||||||
from crewai.cli.create_crew import create_crew
|
from crewai.cli.create_crew import create_crew
|
||||||
from crewai.cli.create_flow import create_flow
|
from crewai.cli.create_flow import create_flow
|
||||||
from crewai.cli.crew_chat import run_chat
|
from crewai.cli.crew_chat import run_chat
|
||||||
from crewai.cli.fetch_chat_llm import fetch_chat_llm
|
|
||||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||||
KickoffTaskOutputsSQLiteStorage,
|
KickoffTaskOutputsSQLiteStorage,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,11 +7,7 @@ from typing import Any, Dict, List, Set, Tuple
|
|||||||
import click
|
import click
|
||||||
import tomli
|
import tomli
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
|
||||||
from crewai.cli.fetch_chat_llm import fetch_chat_llm
|
|
||||||
from crewai.cli.fetch_crew_inputs import fetch_crew_inputs
|
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
from crewai.task import Task
|
|
||||||
from crewai.types.crew_chat import ChatInputField, ChatInputs
|
from crewai.types.crew_chat import ChatInputField, ChatInputs
|
||||||
from crewai.utilities.llm_utils import create_llm
|
from crewai.utilities.llm_utils import create_llm
|
||||||
|
|
||||||
@@ -23,25 +19,44 @@ def run_chat():
|
|||||||
Exits if crew_name or crew_description are missing.
|
Exits if crew_name or crew_description are missing.
|
||||||
"""
|
"""
|
||||||
crew, crew_name = load_crew_and_name()
|
crew, crew_name = load_crew_and_name()
|
||||||
click.secho("\nFetching the Chat LLM...", fg="cyan")
|
chat_llm = initialize_chat_llm(crew)
|
||||||
try:
|
|
||||||
chat_llm = create_llm(crew.chat_llm)
|
|
||||||
except Exception as e:
|
|
||||||
click.secho(f"Failed to retrieve Chat LLM: {e}", fg="red")
|
|
||||||
return
|
|
||||||
if not chat_llm:
|
if not chat_llm:
|
||||||
click.secho("No valid Chat LLM returned. Exiting.", fg="red")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Generate crew chat inputs automatically
|
|
||||||
crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm)
|
crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm)
|
||||||
print("crew_inputs:", crew_chat_inputs)
|
|
||||||
|
|
||||||
# Generate a tool schema from the crew inputs
|
|
||||||
crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
|
crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
|
||||||
print("crew_tool_schema:", crew_tool_schema)
|
system_message = build_system_message(crew_chat_inputs)
|
||||||
|
|
||||||
# Build initial system message
|
messages = [
|
||||||
|
{"role": "system", "content": system_message},
|
||||||
|
]
|
||||||
|
available_functions = {
|
||||||
|
crew_chat_inputs.crew_name: create_tool_function(crew, messages),
|
||||||
|
}
|
||||||
|
|
||||||
|
click.secho(
|
||||||
|
"\nEntering an interactive chat loop with function-calling.\n"
|
||||||
|
"Type 'exit' or Ctrl+C to quit.\n",
|
||||||
|
fg="cyan",
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_chat_llm(crew: Crew) -> Any:
|
||||||
|
"""Initializes the chat LLM and handles exceptions."""
|
||||||
|
try:
|
||||||
|
return create_llm(crew.chat_llm)
|
||||||
|
except Exception as e:
|
||||||
|
click.secho(
|
||||||
|
f"Unable to find a Chat LLM. Please make sure you set chat_llm on the crew: {e}",
|
||||||
|
fg="red",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
||||||
|
"""Builds the initial system message for the chat."""
|
||||||
required_fields_str = (
|
required_fields_str = (
|
||||||
", ".join(
|
", ".join(
|
||||||
f"{field.name} (desc: {field.description or 'n/a'})"
|
f"{field.name} (desc: {field.description or 'n/a'})"
|
||||||
@@ -50,7 +65,7 @@ def run_chat():
|
|||||||
or "(No required fields detected)"
|
or "(No required fields detected)"
|
||||||
)
|
)
|
||||||
|
|
||||||
system_message = (
|
return (
|
||||||
"You are a helpful AI assistant for the CrewAI platform. "
|
"You are a helpful AI assistant for the CrewAI platform. "
|
||||||
"Your primary purpose is to assist users with the crew's specific tasks. "
|
"Your primary purpose is to assist users with the crew's specific tasks. "
|
||||||
"You can answer general questions, but should guide users back to the crew's purpose afterward. "
|
"You can answer general questions, but should guide users back to the crew's purpose afterward. "
|
||||||
@@ -66,26 +81,18 @@ def run_chat():
|
|||||||
f"\nCrew Description: {crew_chat_inputs.crew_description}"
|
f"\nCrew Description: {crew_chat_inputs.crew_description}"
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": system_message},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Create a wrapper function that captures 'crew' and 'messages' from the enclosing scope
|
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
|
||||||
|
"""Creates a wrapper function for running the crew tool with messages."""
|
||||||
|
|
||||||
def run_crew_tool_with_messages(**kwargs):
|
def run_crew_tool_with_messages(**kwargs):
|
||||||
return run_crew_tool(crew, messages, **kwargs)
|
return run_crew_tool(crew, messages, **kwargs)
|
||||||
|
|
||||||
# Prepare available_functions with the wrapper function
|
return run_crew_tool_with_messages
|
||||||
available_functions = {
|
|
||||||
crew_chat_inputs.crew_name: run_crew_tool_with_messages,
|
|
||||||
}
|
|
||||||
|
|
||||||
click.secho(
|
|
||||||
"\nEntering an interactive chat loop with function-calling.\n"
|
|
||||||
"Type 'exit' or Ctrl+C to quit.\n",
|
|
||||||
fg="cyan",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Main chat loop
|
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
|
||||||
|
"""Main chat loop for interacting with the user."""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
user_input = click.prompt("You", type=str)
|
user_input = click.prompt("You", type=str)
|
||||||
@@ -93,20 +100,14 @@ def run_chat():
|
|||||||
click.echo("Exiting chat. Goodbye!")
|
click.echo("Exiting chat. Goodbye!")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Append user message
|
|
||||||
messages.append({"role": "user", "content": user_input})
|
messages.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
# Invoke the LLM, passing tools and available_functions
|
|
||||||
final_response = chat_llm.call(
|
final_response = chat_llm.call(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=[crew_tool_schema],
|
tools=[crew_tool_schema],
|
||||||
available_functions=available_functions,
|
available_functions=available_functions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Append assistant's reply
|
|
||||||
messages.append({"role": "assistant", "content": final_response})
|
messages.append({"role": "assistant", "content": final_response})
|
||||||
|
|
||||||
# Display assistant's reply
|
|
||||||
click.secho(f"\nAssistant: {final_response}\n", fg="green")
|
click.secho(f"\nAssistant: {final_response}\n", fg="green")
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -165,7 +166,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Serialize 'messages' to JSON string before adding to kwargs
|
# Serialize 'messages' to JSON string before adding to kwargs
|
||||||
kwargs['crew_chat_messages'] = json.dumps(messages)
|
kwargs["crew_chat_messages"] = json.dumps(messages)
|
||||||
|
|
||||||
# Run the crew with the provided inputs
|
# Run the crew with the provided inputs
|
||||||
crew_output = crew.kickoff(inputs=kwargs)
|
crew_output = crew.kickoff(inputs=kwargs)
|
||||||
@@ -184,7 +185,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
|
|||||||
def load_crew_and_name() -> Tuple[Crew, str]:
|
def load_crew_and_name() -> Tuple[Crew, str]:
|
||||||
"""
|
"""
|
||||||
Loads the crew by importing the crew class from the user's project.
|
Loads the crew by importing the crew class from the user's project.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
|
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
|
||||||
"""
|
"""
|
||||||
@@ -258,9 +259,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
|
|||||||
crew_description = generate_crew_description_with_ai(crew, chat_llm)
|
crew_description = generate_crew_description_with_ai(crew, chat_llm)
|
||||||
|
|
||||||
return ChatInputs(
|
return ChatInputs(
|
||||||
crew_name=crew_name,
|
crew_name=crew_name, crew_description=crew_description, inputs=input_fields
|
||||||
crew_description=crew_description,
|
|
||||||
inputs=input_fields
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -307,18 +306,31 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
|||||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||||
|
|
||||||
for task in crew.tasks:
|
for task in crew.tasks:
|
||||||
if f"{{{input_name}}}" in task.description or f"{{{input_name}}}" in task.expected_output:
|
if (
|
||||||
|
f"{{{input_name}}}" in task.description
|
||||||
|
or f"{{{input_name}}}" in task.expected_output
|
||||||
|
):
|
||||||
# Replace placeholders with input names
|
# Replace placeholders with input names
|
||||||
task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description)
|
task_description = placeholder_pattern.sub(
|
||||||
expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output)
|
lambda m: m.group(1), task.description
|
||||||
|
)
|
||||||
|
expected_output = placeholder_pattern.sub(
|
||||||
|
lambda m: m.group(1), task.expected_output
|
||||||
|
)
|
||||||
context_texts.append(f"Task Description: {task_description}")
|
context_texts.append(f"Task Description: {task_description}")
|
||||||
context_texts.append(f"Expected Output: {expected_output}")
|
context_texts.append(f"Expected Output: {expected_output}")
|
||||||
for agent in crew.agents:
|
for agent in crew.agents:
|
||||||
if f"{{{input_name}}}" in agent.role or f"{{{input_name}}}" in agent.goal or f"{{{input_name}}}" in agent.backstory:
|
if (
|
||||||
|
f"{{{input_name}}}" in agent.role
|
||||||
|
or f"{{{input_name}}}" in agent.goal
|
||||||
|
or f"{{{input_name}}}" in agent.backstory
|
||||||
|
):
|
||||||
# Replace placeholders with input names
|
# Replace placeholders with input names
|
||||||
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role)
|
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role)
|
||||||
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal)
|
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal)
|
||||||
agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory)
|
agent_backstory = placeholder_pattern.sub(
|
||||||
|
lambda m: m.group(1), agent.backstory
|
||||||
|
)
|
||||||
context_texts.append(f"Agent Role: {agent_role}")
|
context_texts.append(f"Agent Role: {agent_role}")
|
||||||
context_texts.append(f"Agent Goal: {agent_goal}")
|
context_texts.append(f"Agent Goal: {agent_goal}")
|
||||||
context_texts.append(f"Agent Backstory: {agent_backstory}")
|
context_texts.append(f"Agent Backstory: {agent_backstory}")
|
||||||
@@ -357,8 +369,12 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
|||||||
|
|
||||||
for task in crew.tasks:
|
for task in crew.tasks:
|
||||||
# Replace placeholders with input names
|
# Replace placeholders with input names
|
||||||
task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description)
|
task_description = placeholder_pattern.sub(
|
||||||
expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output)
|
lambda m: m.group(1), task.description
|
||||||
|
)
|
||||||
|
expected_output = placeholder_pattern.sub(
|
||||||
|
lambda m: m.group(1), task.expected_output
|
||||||
|
)
|
||||||
context_texts.append(f"Task Description: {task_description}")
|
context_texts.append(f"Task Description: {task_description}")
|
||||||
context_texts.append(f"Expected Output: {expected_output}")
|
context_texts.append(f"Expected Output: {expected_output}")
|
||||||
for agent in crew.agents:
|
for agent in crew.agents:
|
||||||
|
|||||||
@@ -1,81 +0,0 @@
|
|||||||
import json
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
import click
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
from crewai.cli.utils import read_toml
|
|
||||||
from crewai.cli.version import get_crewai_version
|
|
||||||
from crewai.llm import LLM
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_chat_llm() -> LLM:
|
|
||||||
"""
|
|
||||||
Fetch the chat LLM by running "uv run fetch_chat_llm" (or your chosen script name),
|
|
||||||
parsing its JSON stdout, and returning an LLM instance.
|
|
||||||
|
|
||||||
This expects the script "fetch_chat_llm" to print out JSON that represents the
|
|
||||||
LLM parameters (e.g., by calling something like: print(json.dumps(llm.to_dict()))).
|
|
||||||
|
|
||||||
Any error, whether from the subprocess or JSON parsing, will raise a RuntimeError.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# You may change this command to match whatever's in your pyproject.toml [project.scripts].
|
|
||||||
command = ["uv", "run", "fetch_chat_llm"]
|
|
||||||
|
|
||||||
crewai_version = get_crewai_version()
|
|
||||||
min_required_version = "0.87.0" # Adjust as needed
|
|
||||||
|
|
||||||
pyproject_data = read_toml()
|
|
||||||
|
|
||||||
# If old poetry-based setup is detected and version is below min_required_version
|
|
||||||
if pyproject_data.get("tool", {}).get("poetry") and (
|
|
||||||
version.parse(crewai_version) < version.parse(min_required_version)
|
|
||||||
):
|
|
||||||
click.secho(
|
|
||||||
f"You are running an older version of crewAI ({crewai_version}) that uses poetry pyproject.toml.\n"
|
|
||||||
f"Please run `crewai update` to transition your pyproject.toml to use uv.",
|
|
||||||
fg="red",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize a reference to your LLM
|
|
||||||
llm_instance = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = subprocess.run(command, capture_output=True, text=True, check=True)
|
|
||||||
stdout_lines = result.stdout.strip().splitlines()
|
|
||||||
|
|
||||||
# Find the line that contains the JSON data
|
|
||||||
json_line = next(
|
|
||||||
(
|
|
||||||
line
|
|
||||||
for line in stdout_lines
|
|
||||||
if line.startswith("{") and line.endswith("}")
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not json_line:
|
|
||||||
raise RuntimeError(
|
|
||||||
"No valid JSON output received from `fetch_chat_llm` command."
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
llm_data = json.loads(json_line)
|
|
||||||
llm_instance = LLM.from_dict(llm_data)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Unable to parse JSON from `fetch_chat_llm` output: {e}\nOutput: {repr(json_line)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"An unexpected error occurred while fetching chat LLM: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if not llm_instance:
|
|
||||||
raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.")
|
|
||||||
|
|
||||||
return llm_instance
|
|
||||||
@@ -19,22 +19,11 @@ def run():
|
|||||||
Usage example:
|
Usage example:
|
||||||
uv run run_crew -- --topic="New Topic" --some_other_field="Value"
|
uv run run_crew -- --topic="New Topic" --some_other_field="Value"
|
||||||
"""
|
"""
|
||||||
# Default inputs
|
|
||||||
inputs = {
|
inputs = {
|
||||||
'topic': 'AI LLMs'
|
'topic': 'AI LLMs'
|
||||||
# Add any other default fields here
|
# Add any other default fields here
|
||||||
}
|
}
|
||||||
|
|
||||||
# 1) Gather overrides from sys.argv
|
|
||||||
# sys.argv might look like: ['run_crew', '--topic=NewTopic']
|
|
||||||
# But be aware that if you're calling "uv run run_crew", sys.argv might have
|
|
||||||
# additional items. So we typically skip the first 1 or 2 items to get only overrides.
|
|
||||||
overrides = parse_cli_overrides(sys.argv[1:])
|
|
||||||
|
|
||||||
# 2) Merge the overrides into defaults
|
|
||||||
inputs.update(overrides)
|
|
||||||
|
|
||||||
# 3) Kick off the crew with final inputs
|
|
||||||
try:
|
try:
|
||||||
{{crew_name}}().crew().kickoff(inputs=inputs)
|
{{crew_name}}().crew().kickoff(inputs=inputs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -76,93 +65,3 @@ def test():
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred while testing the crew: {e}")
|
raise Exception(f"An error occurred while testing the crew: {e}")
|
||||||
|
|
||||||
def fetch_inputs():
|
|
||||||
"""
|
|
||||||
Command that gathers required placeholders/inputs from the Crew, then
|
|
||||||
prints them as JSON to stdout so external scripts can parse them easily.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
crew = {{crew_name}}().crew()
|
|
||||||
crew_inputs = crew.fetch_inputs()
|
|
||||||
json_string = json.dumps(list(crew_inputs))
|
|
||||||
print(json_string)
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"An error occurred while fetching inputs: {e}")
|
|
||||||
|
|
||||||
def fetch_chat_llm():
|
|
||||||
"""
|
|
||||||
Command that fetches the 'chat_llm' property from the Crew,
|
|
||||||
instantiates it via create_llm(),
|
|
||||||
and prints the resulting LLM as JSON (using LLM.to_dict()) to stdout.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
crew = {{crew_name}}().crew()
|
|
||||||
raw_chat_llm = getattr(crew, "chat_llm", None)
|
|
||||||
|
|
||||||
if not raw_chat_llm:
|
|
||||||
# If the crew doesn't have chat_llm, fallback to create_llm(None)
|
|
||||||
final_llm = create_llm(None)
|
|
||||||
else:
|
|
||||||
# raw_chat_llm might be a dict, or an LLM, or something else
|
|
||||||
final_llm = create_llm(raw_chat_llm)
|
|
||||||
|
|
||||||
if final_llm:
|
|
||||||
# Print the final LLM as JSON, so fetch_chat_llm.py can parse it
|
|
||||||
from crewai.llm import LLM # Import here to avoid circular references
|
|
||||||
|
|
||||||
# Make sure it's an instance of the LLM class:
|
|
||||||
if isinstance(final_llm, LLM):
|
|
||||||
print(json.dumps(final_llm.to_dict()))
|
|
||||||
else:
|
|
||||||
# If somehow it's not an LLM, try to interpret as a dict
|
|
||||||
# or revert to an empty fallback
|
|
||||||
if isinstance(final_llm, dict):
|
|
||||||
print(json.dumps(final_llm))
|
|
||||||
else:
|
|
||||||
print(json.dumps({}))
|
|
||||||
else:
|
|
||||||
print(json.dumps({}))
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"An error occurred while fetching chat LLM: {e}")
|
|
||||||
|
|
||||||
# TODO: Talk to Joao about making using LLM calls to analyze the crew
|
|
||||||
# and generate all of this information automatically
|
|
||||||
def fetch_chat_inputs():
|
|
||||||
"""
|
|
||||||
Command that fetches the 'chat_inputs' property from the Crew,
|
|
||||||
and prints it as JSON to stdout.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
crew = {{crew_name}}().crew()
|
|
||||||
raw_chat_inputs = getattr(crew, "chat_inputs", None)
|
|
||||||
|
|
||||||
if raw_chat_inputs:
|
|
||||||
# Convert to dictionary to print JSON
|
|
||||||
print(json.dumps(raw_chat_inputs.model_dump()))
|
|
||||||
else:
|
|
||||||
# If crew.chat_inputs is None or empty, print an empty JSON
|
|
||||||
print(json.dumps({}))
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"An error occurred while fetching chat inputs: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def parse_cli_overrides(args_list) -> dict:
|
|
||||||
"""
|
|
||||||
Parse arguments in the form of --key=value from a list of CLI arguments.
|
|
||||||
Return them as a dict. For example:
|
|
||||||
['--topic=AI LLMs', '--username=John'] => {'topic': 'AI LLMs', 'username': 'John'}
|
|
||||||
"""
|
|
||||||
overrides = {}
|
|
||||||
for arg in args_list:
|
|
||||||
if arg.startswith("--"):
|
|
||||||
# remove the leading --
|
|
||||||
trimmed = arg[2:]
|
|
||||||
if "=" in trimmed:
|
|
||||||
key, val = trimmed.split("=", 1)
|
|
||||||
overrides[key] = val
|
|
||||||
else:
|
|
||||||
# If someone passed something like --topic (no =),
|
|
||||||
# either handle differently or ignore
|
|
||||||
pass
|
|
||||||
return overrides
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from importlib import resources
|
|
||||||
from typing import Any, Dict, List, Optional, Union, cast
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -179,35 +178,6 @@ class LLM:
|
|||||||
"callbacks": self.callbacks,
|
"callbacks": self.callbacks,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: dict) -> "LLM":
|
|
||||||
"""
|
|
||||||
Create an LLM instance from a dict.
|
|
||||||
We assume the dict has all relevant keys that match what's in the constructor.
|
|
||||||
"""
|
|
||||||
known_fields = {}
|
|
||||||
known_fields["model"] = data.pop("model", None)
|
|
||||||
known_fields["timeout"] = data.pop("timeout", None)
|
|
||||||
known_fields["temperature"] = data.pop("temperature", None)
|
|
||||||
known_fields["top_p"] = data.pop("top_p", None)
|
|
||||||
known_fields["n"] = data.pop("n", None)
|
|
||||||
known_fields["stop"] = data.pop("stop", None)
|
|
||||||
known_fields["max_completion_tokens"] = data.pop("max_completion_tokens", None)
|
|
||||||
known_fields["max_tokens"] = data.pop("max_tokens", None)
|
|
||||||
known_fields["presence_penalty"] = data.pop("presence_penalty", None)
|
|
||||||
known_fields["frequency_penalty"] = data.pop("frequency_penalty", None)
|
|
||||||
known_fields["logit_bias"] = data.pop("logit_bias", None)
|
|
||||||
known_fields["response_format"] = data.pop("response_format", None)
|
|
||||||
known_fields["seed"] = data.pop("seed", None)
|
|
||||||
known_fields["logprobs"] = data.pop("logprobs", None)
|
|
||||||
known_fields["top_logprobs"] = data.pop("top_logprobs", None)
|
|
||||||
known_fields["base_url"] = data.pop("base_url", None)
|
|
||||||
known_fields["api_version"] = data.pop("api_version", None)
|
|
||||||
known_fields["api_key"] = data.pop("api_key", None)
|
|
||||||
known_fields["callbacks"] = data.pop("callbacks", None)
|
|
||||||
|
|
||||||
return cls(**known_fields, **data)
|
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional
|
from typing import List
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -14,10 +14,7 @@ class ChatInputField(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = Field(..., description="The name of the input field")
|
name: str = Field(..., description="The name of the input field")
|
||||||
description: str = Field(
|
description: str = Field(..., description="A short description of the input field")
|
||||||
...,
|
|
||||||
description="A short description of the input field",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatInputs(BaseModel):
|
class ChatInputs(BaseModel):
|
||||||
@@ -36,8 +33,7 @@ class ChatInputs(BaseModel):
|
|||||||
|
|
||||||
crew_name: str = Field(..., description="The name of the crew")
|
crew_name: str = Field(..., description="The name of the crew")
|
||||||
crew_description: str = Field(
|
crew_description: str = Field(
|
||||||
...,
|
..., description="A description of the crew's purpose"
|
||||||
description="A description of the crew's purpose",
|
|
||||||
)
|
)
|
||||||
inputs: List[ChatInputField] = Field(
|
inputs: List[ChatInputField] = Field(
|
||||||
default_factory=list, description="A list of input fields for the crew"
|
default_factory=list, description="A list of input fields for the crew"
|
||||||
|
|||||||
96
tests/utilities/test_llm_utils.py
Normal file
96
tests/utilities/test_llm_utils.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from litellm.exceptions import BadRequestError
|
||||||
|
|
||||||
|
from crewai.llm import LLM
|
||||||
|
from crewai.utilities.llm_utils import create_llm
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_with_llm_instance():
|
||||||
|
existing_llm = LLM(model="gpt-4o")
|
||||||
|
llm = create_llm(llm_value=existing_llm)
|
||||||
|
assert llm is existing_llm
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_with_valid_model_string():
|
||||||
|
llm = create_llm(llm_value="gpt-4o")
|
||||||
|
assert isinstance(llm, LLM)
|
||||||
|
assert llm.model == "gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_with_invalid_model_string():
|
||||||
|
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
|
||||||
|
llm = create_llm(llm_value="invalid-model")
|
||||||
|
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_with_unknown_object_missing_attributes():
|
||||||
|
class UnknownObject:
|
||||||
|
pass
|
||||||
|
|
||||||
|
unknown_obj = UnknownObject()
|
||||||
|
llm = create_llm(llm_value=unknown_obj)
|
||||||
|
|
||||||
|
# Attempt to call the LLM and expect it to raise an error due to missing attributes
|
||||||
|
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
|
||||||
|
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_with_none_uses_default_model():
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
with patch("crewai.cli.constants.DEFAULT_LLM_MODEL", "gpt-4o"):
|
||||||
|
llm = create_llm(llm_value=None)
|
||||||
|
assert isinstance(llm, LLM)
|
||||||
|
assert llm.model == "gpt-4o-mini"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_with_unknown_object():
|
||||||
|
class UnknownObject:
|
||||||
|
model_name = "gpt-4o"
|
||||||
|
temperature = 0.7
|
||||||
|
max_tokens = 1500
|
||||||
|
|
||||||
|
unknown_obj = UnknownObject()
|
||||||
|
llm = create_llm(llm_value=unknown_obj)
|
||||||
|
assert isinstance(llm, LLM)
|
||||||
|
assert llm.model == "gpt-4o"
|
||||||
|
assert llm.temperature == 0.7
|
||||||
|
assert llm.max_tokens == 1500
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_from_env_with_unaccepted_attributes():
|
||||||
|
with patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"OPENAI_MODEL_NAME": "gpt-3.5-turbo",
|
||||||
|
"AWS_ACCESS_KEY_ID": "fake-access-key",
|
||||||
|
"AWS_SECRET_ACCESS_KEY": "fake-secret-key",
|
||||||
|
"AWS_REGION_NAME": "us-west-2",
|
||||||
|
},
|
||||||
|
):
|
||||||
|
llm = create_llm(llm_value=None)
|
||||||
|
assert isinstance(llm, LLM)
|
||||||
|
assert llm.model == "gpt-3.5-turbo"
|
||||||
|
assert not hasattr(llm, "AWS_ACCESS_KEY_ID")
|
||||||
|
assert not hasattr(llm, "AWS_SECRET_ACCESS_KEY")
|
||||||
|
assert not hasattr(llm, "AWS_REGION_NAME")
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_with_partial_attributes():
|
||||||
|
class PartialAttributes:
|
||||||
|
model_name = "gpt-4o"
|
||||||
|
# temperature is missing
|
||||||
|
|
||||||
|
obj = PartialAttributes()
|
||||||
|
llm = create_llm(llm_value=obj)
|
||||||
|
assert isinstance(llm, LLM)
|
||||||
|
assert llm.model == "gpt-4o"
|
||||||
|
assert llm.temperature is None # Should handle missing attributes gracefully
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_with_invalid_type():
|
||||||
|
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
|
||||||
|
llm = create_llm(llm_value=42)
|
||||||
|
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
|
||||||
Reference in New Issue
Block a user