adding tests. cleaing up pr.

This commit is contained in:
Brandon Hancock
2025-01-06 13:12:30 -05:00
parent fadd423771
commit cea5a28db1
7 changed files with 166 additions and 271 deletions

View File

@@ -8,7 +8,6 @@ from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.create_crew import create_crew from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat from crewai.cli.crew_chat import run_chat
from crewai.cli.fetch_chat_llm import fetch_chat_llm
from crewai.memory.storage.kickoff_task_outputs_storage import ( from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage, KickoffTaskOutputsSQLiteStorage,
) )

View File

@@ -7,11 +7,7 @@ from typing import Any, Dict, List, Set, Tuple
import click import click
import tomli import tomli
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.cli.fetch_chat_llm import fetch_chat_llm
from crewai.cli.fetch_crew_inputs import fetch_crew_inputs
from crewai.crew import Crew from crewai.crew import Crew
from crewai.task import Task
from crewai.types.crew_chat import ChatInputField, ChatInputs from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm from crewai.utilities.llm_utils import create_llm
@@ -23,25 +19,44 @@ def run_chat():
Exits if crew_name or crew_description are missing. Exits if crew_name or crew_description are missing.
""" """
crew, crew_name = load_crew_and_name() crew, crew_name = load_crew_and_name()
click.secho("\nFetching the Chat LLM...", fg="cyan") chat_llm = initialize_chat_llm(crew)
try:
chat_llm = create_llm(crew.chat_llm)
except Exception as e:
click.secho(f"Failed to retrieve Chat LLM: {e}", fg="red")
return
if not chat_llm: if not chat_llm:
click.secho("No valid Chat LLM returned. Exiting.", fg="red")
return return
# Generate crew chat inputs automatically
crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm) crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm)
print("crew_inputs:", crew_chat_inputs)
# Generate a tool schema from the crew inputs
crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs) crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
print("crew_tool_schema:", crew_tool_schema) system_message = build_system_message(crew_chat_inputs)
# Build initial system message messages = [
{"role": "system", "content": system_message},
]
available_functions = {
crew_chat_inputs.crew_name: create_tool_function(crew, messages),
}
click.secho(
"\nEntering an interactive chat loop with function-calling.\n"
"Type 'exit' or Ctrl+C to quit.\n",
fg="cyan",
)
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
def initialize_chat_llm(crew: Crew) -> Any:
"""Initializes the chat LLM and handles exceptions."""
try:
return create_llm(crew.chat_llm)
except Exception as e:
click.secho(
f"Unable to find a Chat LLM. Please make sure you set chat_llm on the crew: {e}",
fg="red",
)
return None
def build_system_message(crew_chat_inputs: ChatInputs) -> str:
"""Builds the initial system message for the chat."""
required_fields_str = ( required_fields_str = (
", ".join( ", ".join(
f"{field.name} (desc: {field.description or 'n/a'})" f"{field.name} (desc: {field.description or 'n/a'})"
@@ -50,7 +65,7 @@ def run_chat():
or "(No required fields detected)" or "(No required fields detected)"
) )
system_message = ( return (
"You are a helpful AI assistant for the CrewAI platform. " "You are a helpful AI assistant for the CrewAI platform. "
"Your primary purpose is to assist users with the crew's specific tasks. " "Your primary purpose is to assist users with the crew's specific tasks. "
"You can answer general questions, but should guide users back to the crew's purpose afterward. " "You can answer general questions, but should guide users back to the crew's purpose afterward. "
@@ -66,26 +81,18 @@ def run_chat():
f"\nCrew Description: {crew_chat_inputs.crew_description}" f"\nCrew Description: {crew_chat_inputs.crew_description}"
) )
messages = [
{"role": "system", "content": system_message},
]
# Create a wrapper function that captures 'crew' and 'messages' from the enclosing scope def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
"""Creates a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs): def run_crew_tool_with_messages(**kwargs):
return run_crew_tool(crew, messages, **kwargs) return run_crew_tool(crew, messages, **kwargs)
# Prepare available_functions with the wrapper function return run_crew_tool_with_messages
available_functions = {
crew_chat_inputs.crew_name: run_crew_tool_with_messages,
}
click.secho(
"\nEntering an interactive chat loop with function-calling.\n"
"Type 'exit' or Ctrl+C to quit.\n",
fg="cyan",
)
# Main chat loop def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
"""Main chat loop for interacting with the user."""
while True: while True:
try: try:
user_input = click.prompt("You", type=str) user_input = click.prompt("You", type=str)
@@ -93,20 +100,14 @@ def run_chat():
click.echo("Exiting chat. Goodbye!") click.echo("Exiting chat. Goodbye!")
break break
# Append user message
messages.append({"role": "user", "content": user_input}) messages.append({"role": "user", "content": user_input})
# Invoke the LLM, passing tools and available_functions
final_response = chat_llm.call( final_response = chat_llm.call(
messages=messages, messages=messages,
tools=[crew_tool_schema], tools=[crew_tool_schema],
available_functions=available_functions, available_functions=available_functions,
) )
# Append assistant's reply
messages.append({"role": "assistant", "content": final_response}) messages.append({"role": "assistant", "content": final_response})
# Display assistant's reply
click.secho(f"\nAssistant: {final_response}\n", fg="green") click.secho(f"\nAssistant: {final_response}\n", fg="green")
except KeyboardInterrupt: except KeyboardInterrupt:
@@ -165,7 +166,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
""" """
try: try:
# Serialize 'messages' to JSON string before adding to kwargs # Serialize 'messages' to JSON string before adding to kwargs
kwargs['crew_chat_messages'] = json.dumps(messages) kwargs["crew_chat_messages"] = json.dumps(messages)
# Run the crew with the provided inputs # Run the crew with the provided inputs
crew_output = crew.kickoff(inputs=kwargs) crew_output = crew.kickoff(inputs=kwargs)
@@ -184,7 +185,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
def load_crew_and_name() -> Tuple[Crew, str]: def load_crew_and_name() -> Tuple[Crew, str]:
""" """
Loads the crew by importing the crew class from the user's project. Loads the crew by importing the crew class from the user's project.
Returns: Returns:
Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew. Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
""" """
@@ -258,9 +259,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
crew_description = generate_crew_description_with_ai(crew, chat_llm) crew_description = generate_crew_description_with_ai(crew, chat_llm)
return ChatInputs( return ChatInputs(
crew_name=crew_name, crew_name=crew_name, crew_description=crew_description, inputs=input_fields
crew_description=crew_description,
inputs=input_fields
) )
@@ -307,18 +306,31 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
placeholder_pattern = re.compile(r"\{(.+?)\}") placeholder_pattern = re.compile(r"\{(.+?)\}")
for task in crew.tasks: for task in crew.tasks:
if f"{{{input_name}}}" in task.description or f"{{{input_name}}}" in task.expected_output: if (
f"{{{input_name}}}" in task.description
or f"{{{input_name}}}" in task.expected_output
):
# Replace placeholders with input names # Replace placeholders with input names
task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description) task_description = placeholder_pattern.sub(
expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output) lambda m: m.group(1), task.description
)
expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output
)
context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}") context_texts.append(f"Expected Output: {expected_output}")
for agent in crew.agents: for agent in crew.agents:
if f"{{{input_name}}}" in agent.role or f"{{{input_name}}}" in agent.goal or f"{{{input_name}}}" in agent.backstory: if (
f"{{{input_name}}}" in agent.role
or f"{{{input_name}}}" in agent.goal
or f"{{{input_name}}}" in agent.backstory
):
# Replace placeholders with input names # Replace placeholders with input names
agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role) agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role)
agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal) agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal)
agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory) agent_backstory = placeholder_pattern.sub(
lambda m: m.group(1), agent.backstory
)
context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Role: {agent_role}")
context_texts.append(f"Agent Goal: {agent_goal}") context_texts.append(f"Agent Goal: {agent_goal}")
context_texts.append(f"Agent Backstory: {agent_backstory}") context_texts.append(f"Agent Backstory: {agent_backstory}")
@@ -357,8 +369,12 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
for task in crew.tasks: for task in crew.tasks:
# Replace placeholders with input names # Replace placeholders with input names
task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description) task_description = placeholder_pattern.sub(
expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output) lambda m: m.group(1), task.description
)
expected_output = placeholder_pattern.sub(
lambda m: m.group(1), task.expected_output
)
context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Task Description: {task_description}")
context_texts.append(f"Expected Output: {expected_output}") context_texts.append(f"Expected Output: {expected_output}")
for agent in crew.agents: for agent in crew.agents:

View File

@@ -1,81 +0,0 @@
import json
import subprocess
import click
from packaging import version
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.llm import LLM
def fetch_chat_llm() -> LLM:
"""
Fetch the chat LLM by running "uv run fetch_chat_llm" (or your chosen script name),
parsing its JSON stdout, and returning an LLM instance.
This expects the script "fetch_chat_llm" to print out JSON that represents the
LLM parameters (e.g., by calling something like: print(json.dumps(llm.to_dict()))).
Any error, whether from the subprocess or JSON parsing, will raise a RuntimeError.
"""
# You may change this command to match whatever's in your pyproject.toml [project.scripts].
command = ["uv", "run", "fetch_chat_llm"]
crewai_version = get_crewai_version()
min_required_version = "0.87.0" # Adjust as needed
pyproject_data = read_toml()
# If old poetry-based setup is detected and version is below min_required_version
if pyproject_data.get("tool", {}).get("poetry") and (
version.parse(crewai_version) < version.parse(min_required_version)
):
click.secho(
f"You are running an older version of crewAI ({crewai_version}) that uses poetry pyproject.toml.\n"
f"Please run `crewai update` to transition your pyproject.toml to use uv.",
fg="red",
)
# Initialize a reference to your LLM
llm_instance = None
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
stdout_lines = result.stdout.strip().splitlines()
# Find the line that contains the JSON data
json_line = next(
(
line
for line in stdout_lines
if line.startswith("{") and line.endswith("}")
),
None,
)
if not json_line:
raise RuntimeError(
"No valid JSON output received from `fetch_chat_llm` command."
)
try:
llm_data = json.loads(json_line)
llm_instance = LLM.from_dict(llm_data)
except json.JSONDecodeError as e:
raise RuntimeError(
f"Unable to parse JSON from `fetch_chat_llm` output: {e}\nOutput: {repr(json_line)}"
) from e
except subprocess.CalledProcessError as e:
raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e
except Exception as e:
raise RuntimeError(
f"An unexpected error occurred while fetching chat LLM: {e}"
) from e
if not llm_instance:
raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.")
return llm_instance

View File

@@ -19,22 +19,11 @@ def run():
Usage example: Usage example:
uv run run_crew -- --topic="New Topic" --some_other_field="Value" uv run run_crew -- --topic="New Topic" --some_other_field="Value"
""" """
# Default inputs
inputs = { inputs = {
'topic': 'AI LLMs' 'topic': 'AI LLMs'
# Add any other default fields here # Add any other default fields here
} }
# 1) Gather overrides from sys.argv
# sys.argv might look like: ['run_crew', '--topic=NewTopic']
# But be aware that if you're calling "uv run run_crew", sys.argv might have
# additional items. So we typically skip the first 1 or 2 items to get only overrides.
overrides = parse_cli_overrides(sys.argv[1:])
# 2) Merge the overrides into defaults
inputs.update(overrides)
# 3) Kick off the crew with final inputs
try: try:
{{crew_name}}().crew().kickoff(inputs=inputs) {{crew_name}}().crew().kickoff(inputs=inputs)
except Exception as e: except Exception as e:
@@ -76,93 +65,3 @@ def test():
except Exception as e: except Exception as e:
raise Exception(f"An error occurred while testing the crew: {e}") raise Exception(f"An error occurred while testing the crew: {e}")
def fetch_inputs():
"""
Command that gathers required placeholders/inputs from the Crew, then
prints them as JSON to stdout so external scripts can parse them easily.
"""
try:
crew = {{crew_name}}().crew()
crew_inputs = crew.fetch_inputs()
json_string = json.dumps(list(crew_inputs))
print(json_string)
except Exception as e:
raise Exception(f"An error occurred while fetching inputs: {e}")
def fetch_chat_llm():
"""
Command that fetches the 'chat_llm' property from the Crew,
instantiates it via create_llm(),
and prints the resulting LLM as JSON (using LLM.to_dict()) to stdout.
"""
try:
crew = {{crew_name}}().crew()
raw_chat_llm = getattr(crew, "chat_llm", None)
if not raw_chat_llm:
# If the crew doesn't have chat_llm, fallback to create_llm(None)
final_llm = create_llm(None)
else:
# raw_chat_llm might be a dict, or an LLM, or something else
final_llm = create_llm(raw_chat_llm)
if final_llm:
# Print the final LLM as JSON, so fetch_chat_llm.py can parse it
from crewai.llm import LLM # Import here to avoid circular references
# Make sure it's an instance of the LLM class:
if isinstance(final_llm, LLM):
print(json.dumps(final_llm.to_dict()))
else:
# If somehow it's not an LLM, try to interpret as a dict
# or revert to an empty fallback
if isinstance(final_llm, dict):
print(json.dumps(final_llm))
else:
print(json.dumps({}))
else:
print(json.dumps({}))
except Exception as e:
raise Exception(f"An error occurred while fetching chat LLM: {e}")
# TODO: Talk to Joao about making using LLM calls to analyze the crew
# and generate all of this information automatically
def fetch_chat_inputs():
"""
Command that fetches the 'chat_inputs' property from the Crew,
and prints it as JSON to stdout.
"""
try:
crew = {{crew_name}}().crew()
raw_chat_inputs = getattr(crew, "chat_inputs", None)
if raw_chat_inputs:
# Convert to dictionary to print JSON
print(json.dumps(raw_chat_inputs.model_dump()))
else:
# If crew.chat_inputs is None or empty, print an empty JSON
print(json.dumps({}))
except Exception as e:
raise Exception(f"An error occurred while fetching chat inputs: {e}")
def parse_cli_overrides(args_list) -> dict:
"""
Parse arguments in the form of --key=value from a list of CLI arguments.
Return them as a dict. For example:
['--topic=AI LLMs', '--username=John'] => {'topic': 'AI LLMs', 'username': 'John'}
"""
overrides = {}
for arg in args_list:
if arg.startswith("--"):
# remove the leading --
trimmed = arg[2:]
if "=" in trimmed:
key, val = trimmed.split("=", 1)
overrides[key] = val
else:
# If someone passed something like --topic (no =),
# either handle differently or ignore
pass
return overrides

View File

@@ -5,7 +5,6 @@ import sys
import threading import threading
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from importlib import resources
from typing import Any, Dict, List, Optional, Union, cast from typing import Any, Dict, List, Optional, Union, cast
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -179,35 +178,6 @@ class LLM:
"callbacks": self.callbacks, "callbacks": self.callbacks,
} }
@classmethod
def from_dict(cls, data: dict) -> "LLM":
"""
Create an LLM instance from a dict.
We assume the dict has all relevant keys that match what's in the constructor.
"""
known_fields = {}
known_fields["model"] = data.pop("model", None)
known_fields["timeout"] = data.pop("timeout", None)
known_fields["temperature"] = data.pop("temperature", None)
known_fields["top_p"] = data.pop("top_p", None)
known_fields["n"] = data.pop("n", None)
known_fields["stop"] = data.pop("stop", None)
known_fields["max_completion_tokens"] = data.pop("max_completion_tokens", None)
known_fields["max_tokens"] = data.pop("max_tokens", None)
known_fields["presence_penalty"] = data.pop("presence_penalty", None)
known_fields["frequency_penalty"] = data.pop("frequency_penalty", None)
known_fields["logit_bias"] = data.pop("logit_bias", None)
known_fields["response_format"] = data.pop("response_format", None)
known_fields["seed"] = data.pop("seed", None)
known_fields["logprobs"] = data.pop("logprobs", None)
known_fields["top_logprobs"] = data.pop("top_logprobs", None)
known_fields["base_url"] = data.pop("base_url", None)
known_fields["api_version"] = data.pop("api_version", None)
known_fields["api_key"] = data.pop("api_key", None)
known_fields["callbacks"] = data.pop("callbacks", None)
return cls(**known_fields, **data)
def call( def call(
self, self,
messages: List[Dict[str, str]], messages: List[Dict[str, str]],

View File

@@ -1,4 +1,4 @@
from typing import List, Optional from typing import List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -14,10 +14,7 @@ class ChatInputField(BaseModel):
""" """
name: str = Field(..., description="The name of the input field") name: str = Field(..., description="The name of the input field")
description: str = Field( description: str = Field(..., description="A short description of the input field")
...,
description="A short description of the input field",
)
class ChatInputs(BaseModel): class ChatInputs(BaseModel):
@@ -36,8 +33,7 @@ class ChatInputs(BaseModel):
crew_name: str = Field(..., description="The name of the crew") crew_name: str = Field(..., description="The name of the crew")
crew_description: str = Field( crew_description: str = Field(
..., ..., description="A description of the crew's purpose"
description="A description of the crew's purpose",
) )
inputs: List[ChatInputField] = Field( inputs: List[ChatInputField] = Field(
default_factory=list, description="A list of input fields for the crew" default_factory=list, description="A list of input fields for the crew"

View File

@@ -0,0 +1,96 @@
import os
from unittest.mock import patch
import pytest
from litellm.exceptions import BadRequestError
from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm
def test_create_llm_with_llm_instance():
existing_llm = LLM(model="gpt-4o")
llm = create_llm(llm_value=existing_llm)
assert llm is existing_llm
def test_create_llm_with_valid_model_string():
llm = create_llm(llm_value="gpt-4o")
assert isinstance(llm, LLM)
assert llm.model == "gpt-4o"
def test_create_llm_with_invalid_model_string():
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
llm = create_llm(llm_value="invalid-model")
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
def test_create_llm_with_unknown_object_missing_attributes():
class UnknownObject:
pass
unknown_obj = UnknownObject()
llm = create_llm(llm_value=unknown_obj)
# Attempt to call the LLM and expect it to raise an error due to missing attributes
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
def test_create_llm_with_none_uses_default_model():
with patch.dict(os.environ, {}, clear=True):
with patch("crewai.cli.constants.DEFAULT_LLM_MODEL", "gpt-4o"):
llm = create_llm(llm_value=None)
assert isinstance(llm, LLM)
assert llm.model == "gpt-4o-mini"
def test_create_llm_with_unknown_object():
class UnknownObject:
model_name = "gpt-4o"
temperature = 0.7
max_tokens = 1500
unknown_obj = UnknownObject()
llm = create_llm(llm_value=unknown_obj)
assert isinstance(llm, LLM)
assert llm.model == "gpt-4o"
assert llm.temperature == 0.7
assert llm.max_tokens == 1500
def test_create_llm_from_env_with_unaccepted_attributes():
with patch.dict(
os.environ,
{
"OPENAI_MODEL_NAME": "gpt-3.5-turbo",
"AWS_ACCESS_KEY_ID": "fake-access-key",
"AWS_SECRET_ACCESS_KEY": "fake-secret-key",
"AWS_REGION_NAME": "us-west-2",
},
):
llm = create_llm(llm_value=None)
assert isinstance(llm, LLM)
assert llm.model == "gpt-3.5-turbo"
assert not hasattr(llm, "AWS_ACCESS_KEY_ID")
assert not hasattr(llm, "AWS_SECRET_ACCESS_KEY")
assert not hasattr(llm, "AWS_REGION_NAME")
def test_create_llm_with_partial_attributes():
class PartialAttributes:
model_name = "gpt-4o"
# temperature is missing
obj = PartialAttributes()
llm = create_llm(llm_value=obj)
assert isinstance(llm, LLM)
assert llm.model == "gpt-4o"
assert llm.temperature is None # Should handle missing attributes gracefully
def test_create_llm_with_invalid_type():
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
llm = create_llm(llm_value=42)
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])