adding tests. cleaing up pr.

This commit is contained in:
Brandon Hancock
2025-01-06 13:12:30 -05:00
parent fadd423771
commit cea5a28db1
7 changed files with 166 additions and 271 deletions

View File

@@ -8,7 +8,6 @@ from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat
from crewai.cli.fetch_chat_llm import fetch_chat_llm
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)

View File

@@ -7,11 +7,7 @@ from typing import Any, Dict, List, Set, Tuple
import click
import tomli
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.cli.fetch_chat_llm import fetch_chat_llm
from crewai.cli.fetch_crew_inputs import fetch_crew_inputs
from crewai.crew import Crew
from crewai.task import Task
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm
@@ -23,25 +19,44 @@ def run_chat():
Exits if crew_name or crew_description are missing.
"""
crew, crew_name = load_crew_and_name()
click.secho("\nFetching the Chat LLM...", fg="cyan")
try:
chat_llm = create_llm(crew.chat_llm)
except Exception as e:
click.secho(f"Failed to retrieve Chat LLM: {e}", fg="red")
return
chat_llm = initialize_chat_llm(crew)
if not chat_llm:
click.secho("No valid Chat LLM returned. Exiting.", fg="red")
return
# Generate crew chat inputs automatically
crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm)
print("crew_inputs:", crew_chat_inputs)
# Generate a tool schema from the crew inputs
crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
print("crew_tool_schema:", crew_tool_schema)
system_message = build_system_message(crew_chat_inputs)
# Build initial system message
messages = [
{"role": "system", "content": system_message},
]
available_functions = {
crew_chat_inputs.crew_name: create_tool_function(crew, messages),
}
click.secho(
"\nEntering an interactive chat loop with function-calling.\n"
"Type 'exit' or Ctrl+C to quit.\n",
fg="cyan",
)
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
def initialize_chat_llm(crew: Crew) -> Any:
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)
except Exception as e:
click.secho(
f"Unable to find a Chat LLM. Please make sure you set chat_llm on the crew: {e}",
fg="red",
)
return None
def build_system_message(crew_chat_inputs: ChatInputs) -> str:
"""Builds the initial system message for the chat."""
required_fields_str = (
", ".join(
f"{field.name} (desc: {field.description or 'n/a'})"
@@ -50,7 +65,7 @@ def run_chat():
or "(No required fields detected)"
)
system_message = (
return (
"You are a helpful AI assistant for the CrewAI platform. "
"Your primary purpose is to assist users with the crew's specific tasks. "
"You can answer general questions, but should guide users back to the crew's purpose afterward. "
@@ -66,26 +81,18 @@ def run_chat():
f"\nCrew Description: {crew_chat_inputs.crew_description}"
)
messages = [
{"role": "system", "content": system_message},
]
# Create a wrapper function that captures 'crew' and 'messages' from the enclosing scope
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
"""Creates a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs):
return run_crew_tool(crew, messages, **kwargs)
# Prepare available_functions with the wrapper function
available_functions = {
crew_chat_inputs.crew_name: run_crew_tool_with_messages,
}
return run_crew_tool_with_messages
click.secho(
"\nEntering an interactive chat loop with function-calling.\n"
"Type 'exit' or Ctrl+C to quit.\n",
fg="cyan",
)
# Main chat loop
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
"""Main chat loop for interacting with the user."""
while True:
try:
user_input = click.prompt("You", type=str)
@@ -93,20 +100,14 @@ def run_chat():
click.echo("Exiting chat. Goodbye!")
break
# Append user message
messages.append({"role": "user", "content": user_input})
# Invoke the LLM, passing tools and available_functions
final_response = chat_llm.call(
messages=messages,
tools=[crew_tool_schema],
available_functions=available_functions,
)
# Append assistant's reply
messages.append({"role": "assistant", "content": final_response})
# Display assistant's reply
click.secho(f"\nAssistant: {final_response}\n", fg="green")
except KeyboardInterrupt:
@@ -165,7 +166,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
"""
try:
# Serialize 'messages' to JSON string before adding to kwargs
kwargs['crew_chat_messages'] = json.dumps(messages)
kwargs["crew_chat_messages"] = json.dumps(messages)
# Run the crew with the provided inputs
crew_output = crew.kickoff(inputs=kwargs)
@@ -184,7 +185,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
def load_crew_and_name() -> Tuple[Crew, str]:
"""
Loads the crew by importing the crew class from the user's project.
Returns:
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
"""
@@ -258,9 +259,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
crew_description = generate_crew_description_with_ai(crew, chat_llm)
return ChatInputs(
crew_name=crew_name,
crew_description=crew_description,
inputs=input_fields
crew_name=crew_name, crew_description=crew_description, inputs=input_fields
)
@@ -307,18 +306,31 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
placeholder_pattern = re.compile(r"\{(.+?)\}")
for task in crew.tasks:
if f"{{{input_name}}}" in task.description or f"{{{input_name}}}" in task.expected_output:
if (
f"{{{input_name}}}" in task.description
or f"{{{input_name}}}" in task.expected_output
):
# Replace placeholders with input names
task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description)
expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output)
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description
)
expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output
)
context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}")
for agent in crew.agents:
if f"{{{input_name}}}" in agent.role or f"{{{input_name}}}" in agent.goal or f"{{{input_name}}}" in agent.backstory:
if (
f"{{{input_name}}}" in agent.role
or f"{{{input_name}}}" in agent.goal
or f"{{{input_name}}}" in agent.backstory
):
# Replace placeholders with input names
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role)
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal)
agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory)
agent_backstory = placeholder_pattern.sub(
lambda m: m.group(1), agent.backstory
)
context_texts.append(f"Agent Role: {agent_role}")
context_texts.append(f"Agent Goal: {agent_goal}")
context_texts.append(f"Agent Backstory: {agent_backstory}")
@@ -357,8 +369,12 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
for task in crew.tasks:
# Replace placeholders with input names
task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description)
expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output)
task_description = placeholder_pattern.sub(
lambda m: m.group(1), task.description
)
expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output
)
context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}")
for agent in crew.agents:

View File

@@ -1,81 +0,0 @@
import json
import subprocess
import click
from packaging import version
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.llm import LLM
def fetch_chat_llm() -> LLM:
"""
Fetch the chat LLM by running "uv run fetch_chat_llm" (or your chosen script name),
parsing its JSON stdout, and returning an LLM instance.
This expects the script "fetch_chat_llm" to print out JSON that represents the
LLM parameters (e.g., by calling something like: print(json.dumps(llm.to_dict()))).
Any error, whether from the subprocess or JSON parsing, will raise a RuntimeError.
"""
# You may change this command to match whatever's in your pyproject.toml [project.scripts].
command = ["uv", "run", "fetch_chat_llm"]
crewai_version = get_crewai_version()
min_required_version = "0.87.0" # Adjust as needed
pyproject_data = read_toml()
# If old poetry-based setup is detected and version is below min_required_version
if pyproject_data.get("tool", {}).get("poetry") and (
version.parse(crewai_version) < version.parse(min_required_version)
):
click.secho(
f"You are running an older version of crewAI ({crewai_version}) that uses poetry pyproject.toml.\n"
f"Please run `crewai update` to transition your pyproject.toml to use uv.",
fg="red",
)
# Initialize a reference to your LLM
llm_instance = None
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
stdout_lines = result.stdout.strip().splitlines()
# Find the line that contains the JSON data
json_line = next(
(
line
for line in stdout_lines
if line.startswith("{") and line.endswith("}")
),
None,
)
if not json_line:
raise RuntimeError(
"No valid JSON output received from `fetch_chat_llm` command."
)
try:
llm_data = json.loads(json_line)
llm_instance = LLM.from_dict(llm_data)
except json.JSONDecodeError as e:
raise RuntimeError(
f"Unable to parse JSON from `fetch_chat_llm` output: {e}\nOutput: {repr(json_line)}"
) from e
except subprocess.CalledProcessError as e:
raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while fetching chat LLM: {e}"
) from e
if not llm_instance:
raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.")
return llm_instance

View File

@@ -19,22 +19,11 @@ def run():
Usage example:
uv run run_crew -- --topic="New Topic" --some_other_field="Value"
"""
# Default inputs
inputs = {
'topic': 'AI LLMs'
# Add any other default fields here
}
# 1) Gather overrides from sys.argv
# sys.argv might look like: ['run_crew', '--topic=NewTopic']
# But be aware that if you're calling "uv run run_crew", sys.argv might have
# additional items. So we typically skip the first 1 or 2 items to get only overrides.
overrides = parse_cli_overrides(sys.argv[1:])
# 2) Merge the overrides into defaults
inputs.update(overrides)
# 3) Kick off the crew with final inputs
try:
{{crew_name}}().crew().kickoff(inputs=inputs)
except Exception as e:
@@ -76,93 +65,3 @@ def test():
except Exception as e:
raise Exception(f"An error occurred while testing the crew: {e}")
def fetch_inputs():
"""
Command that gathers required placeholders/inputs from the Crew, then
prints them as JSON to stdout so external scripts can parse them easily.
"""
try:
crew = {{crew_name}}().crew()
crew_inputs = crew.fetch_inputs()
json_string = json.dumps(list(crew_inputs))
print(json_string)
except Exception as e:
raise Exception(f"An error occurred while fetching inputs: {e}")
def fetch_chat_llm():
"""
Command that fetches the 'chat_llm' property from the Crew,
instantiates it via create_llm(),
and prints the resulting LLM as JSON (using LLM.to_dict()) to stdout.
"""
try:
crew = {{crew_name}}().crew()
raw_chat_llm = getattr(crew, "chat_llm", None)
if not raw_chat_llm:
# If the crew doesn't have chat_llm, fallback to create_llm(None)
final_llm = create_llm(None)
else:
# raw_chat_llm might be a dict, or an LLM, or something else
final_llm = create_llm(raw_chat_llm)
if final_llm:
# Print the final LLM as JSON, so fetch_chat_llm.py can parse it
from crewai.llm import LLM # Import here to avoid circular references
# Make sure it's an instance of the LLM class:
if isinstance(final_llm, LLM):
print(json.dumps(final_llm.to_dict()))
else:
# If somehow it's not an LLM, try to interpret as a dict
# or revert to an empty fallback
if isinstance(final_llm, dict):
print(json.dumps(final_llm))
else:
print(json.dumps({}))
else:
print(json.dumps({}))
except Exception as e:
raise Exception(f"An error occurred while fetching chat LLM: {e}")
# TODO: Talk to Joao about making using LLM calls to analyze the crew
# and generate all of this information automatically
def fetch_chat_inputs():
"""
Command that fetches the 'chat_inputs' property from the Crew,
and prints it as JSON to stdout.
"""
try:
crew = {{crew_name}}().crew()
raw_chat_inputs = getattr(crew, "chat_inputs", None)
if raw_chat_inputs:
# Convert to dictionary to print JSON
print(json.dumps(raw_chat_inputs.model_dump()))
else:
# If crew.chat_inputs is None or empty, print an empty JSON
print(json.dumps({}))
except Exception as e:
raise Exception(f"An error occurred while fetching chat inputs: {e}")
def parse_cli_overrides(args_list) -> dict:
"""
Parse arguments in the form of --key=value from a list of CLI arguments.
Return them as a dict. For example:
['--topic=AI LLMs', '--username=John'] => {'topic': 'AI LLMs', 'username': 'John'}
"""
overrides = {}
for arg in args_list:
if arg.startswith("--"):
# remove the leading --
trimmed = arg[2:]
if "=" in trimmed:
key, val = trimmed.split("=", 1)
overrides[key] = val
else:
# If someone passed something like --topic (no =),
# either handle differently or ignore
pass
return overrides

View File

@@ -5,7 +5,6 @@ import sys
import threading
import warnings
from contextlib import contextmanager
from importlib import resources
from typing import Any, Dict, List, Optional, Union, cast
from dotenv import load_dotenv
@@ -179,35 +178,6 @@ class LLM:
"callbacks": self.callbacks,
}
@classmethod
def from_dict(cls, data: dict) -> "LLM":
"""
Create an LLM instance from a dict.
We assume the dict has all relevant keys that match what's in the constructor.
"""
known_fields = {}
known_fields["model"] = data.pop("model", None)
known_fields["timeout"] = data.pop("timeout", None)
known_fields["temperature"] = data.pop("temperature", None)
known_fields["top_p"] = data.pop("top_p", None)
known_fields["n"] = data.pop("n", None)
known_fields["stop"] = data.pop("stop", None)
known_fields["max_completion_tokens"] = data.pop("max_completion_tokens", None)
known_fields["max_tokens"] = data.pop("max_tokens", None)
known_fields["presence_penalty"] = data.pop("presence_penalty", None)
known_fields["frequency_penalty"] = data.pop("frequency_penalty", None)
known_fields["logit_bias"] = data.pop("logit_bias", None)
known_fields["response_format"] = data.pop("response_format", None)
known_fields["seed"] = data.pop("seed", None)
known_fields["logprobs"] = data.pop("logprobs", None)
known_fields["top_logprobs"] = data.pop("top_logprobs", None)
known_fields["base_url"] = data.pop("base_url", None)
known_fields["api_version"] = data.pop("api_version", None)
known_fields["api_key"] = data.pop("api_key", None)
known_fields["callbacks"] = data.pop("callbacks", None)
return cls(**known_fields, **data)
def call(
self,
messages: List[Dict[str, str]],

View File

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List
from pydantic import BaseModel, Field
@@ -14,10 +14,7 @@ class ChatInputField(BaseModel):
"""
name: str = Field(..., description="The name of the input field")
description: str = Field(
...,
description="A short description of the input field",
)
description: str = Field(..., description="A short description of the input field")
class ChatInputs(BaseModel):
@@ -36,8 +33,7 @@ class ChatInputs(BaseModel):
crew_name: str = Field(..., description="The name of the crew")
crew_description: str = Field(
...,
description="A description of the crew's purpose",
..., description="A description of the crew's purpose"
)
inputs: List[ChatInputField] = Field(
default_factory=list, description="A list of input fields for the crew"

View File

@@ -0,0 +1,96 @@
import os
from unittest.mock import patch
import pytest
from litellm.exceptions import BadRequestError
from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm
def test_create_llm_with_llm_instance():
existing_llm = LLM(model="gpt-4o")
llm = create_llm(llm_value=existing_llm)
assert llm is existing_llm
def test_create_llm_with_valid_model_string():
llm = create_llm(llm_value="gpt-4o")
assert isinstance(llm, LLM)
assert llm.model == "gpt-4o"
def test_create_llm_with_invalid_model_string():
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
llm = create_llm(llm_value="invalid-model")
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
def test_create_llm_with_unknown_object_missing_attributes():
class UnknownObject:
pass
unknown_obj = UnknownObject()
llm = create_llm(llm_value=unknown_obj)
# Attempt to call the LLM and expect it to raise an error due to missing attributes
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
def test_create_llm_with_none_uses_default_model():
with patch.dict(os.environ, {}, clear=True):
with patch("crewai.cli.constants.DEFAULT_LLM_MODEL", "gpt-4o"):
llm = create_llm(llm_value=None)
assert isinstance(llm, LLM)
assert llm.model == "gpt-4o-mini"
def test_create_llm_with_unknown_object():
class UnknownObject:
model_name = "gpt-4o"
temperature = 0.7
max_tokens = 1500
unknown_obj = UnknownObject()
llm = create_llm(llm_value=unknown_obj)
assert isinstance(llm, LLM)
assert llm.model == "gpt-4o"
assert llm.temperature == 0.7
assert llm.max_tokens == 1500
def test_create_llm_from_env_with_unaccepted_attributes():
with patch.dict(
os.environ,
{
"OPENAI_MODEL_NAME": "gpt-3.5-turbo",
"AWS_ACCESS_KEY_ID": "fake-access-key",
"AWS_SECRET_ACCESS_KEY": "fake-secret-key",
"AWS_REGION_NAME": "us-west-2",
},
):
llm = create_llm(llm_value=None)
assert isinstance(llm, LLM)
assert llm.model == "gpt-3.5-turbo"
assert not hasattr(llm, "AWS_ACCESS_KEY_ID")
assert not hasattr(llm, "AWS_SECRET_ACCESS_KEY")
assert not hasattr(llm, "AWS_REGION_NAME")
def test_create_llm_with_partial_attributes():
class PartialAttributes:
model_name = "gpt-4o"
# temperature is missing
obj = PartialAttributes()
llm = create_llm(llm_value=obj)
assert isinstance(llm, LLM)
assert llm.model == "gpt-4o"
assert llm.temperature is None # Should handle missing attributes gracefully
def test_create_llm_with_invalid_type():
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
llm = create_llm(llm_value=42)
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])