Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-09 08:08:32 +00:00
Adding tests. Cleaning up PR.
@@ -8,7 +8,6 @@ from crewai.cli.add_crew_to_flow import add_crew_to_flow
from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat
from crewai.cli.fetch_chat_llm import fetch_chat_llm
from crewai.memory.storage.kickoff_task_outputs_storage import (
    KickoffTaskOutputsSQLiteStorage,
)

@@ -7,11 +7,7 @@ from typing import Any, Dict, List, Set, Tuple
import click
import tomli

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.cli.fetch_chat_llm import fetch_chat_llm
from crewai.cli.fetch_crew_inputs import fetch_crew_inputs
from crewai.crew import Crew
from crewai.task import Task
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm

@@ -23,25 +19,44 @@ def run_chat():
    Exits if crew_name or crew_description are missing.
    """
    crew, crew_name = load_crew_and_name()
    click.secho("\nFetching the Chat LLM...", fg="cyan")
    try:
        chat_llm = create_llm(crew.chat_llm)
    except Exception as e:
        click.secho(f"Failed to retrieve Chat LLM: {e}", fg="red")
        return
    chat_llm = initialize_chat_llm(crew)
    if not chat_llm:
        click.secho("No valid Chat LLM returned. Exiting.", fg="red")
        return

    # Generate crew chat inputs automatically
    crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm)
    print("crew_inputs:", crew_chat_inputs)

    # Generate a tool schema from the crew inputs
    crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
    print("crew_tool_schema:", crew_tool_schema)
    system_message = build_system_message(crew_chat_inputs)

    # Build initial system message
    messages = [
        {"role": "system", "content": system_message},
    ]
    available_functions = {
        crew_chat_inputs.crew_name: create_tool_function(crew, messages),
    }

    click.secho(
        "\nEntering an interactive chat loop with function-calling.\n"
        "Type 'exit' or Ctrl+C to quit.\n",
        fg="cyan",
    )

    chat_loop(chat_llm, messages, crew_tool_schema, available_functions)


def initialize_chat_llm(crew: Crew) -> Any:
    """Initializes the chat LLM and handles exceptions."""
    try:
        return create_llm(crew.chat_llm)
    except Exception as e:
        click.secho(
            f"Unable to find a Chat LLM. Please make sure you set chat_llm on the crew: {e}",
            fg="red",
        )
        return None


def build_system_message(crew_chat_inputs: ChatInputs) -> str:
    """Builds the initial system message for the chat."""
    required_fields_str = (
        ", ".join(
            f"{field.name} (desc: {field.description or 'n/a'})"
@@ -50,7 +65,7 @@ def run_chat():
        or "(No required fields detected)"
    )

    system_message = (
    return (
        "You are a helpful AI assistant for the CrewAI platform. "
        "Your primary purpose is to assist users with the crew's specific tasks. "
        "You can answer general questions, but should guide users back to the crew's purpose afterward. "
@@ -66,26 +81,18 @@ def run_chat():
        f"\nCrew Description: {crew_chat_inputs.crew_description}"
    )

    messages = [
        {"role": "system", "content": system_message},
    ]

    # Create a wrapper function that captures 'crew' and 'messages' from the enclosing scope
def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any:
    """Creates a wrapper function for running the crew tool with messages."""

    def run_crew_tool_with_messages(**kwargs):
        return run_crew_tool(crew, messages, **kwargs)

    # Prepare available_functions with the wrapper function
    available_functions = {
        crew_chat_inputs.crew_name: run_crew_tool_with_messages,
    }
    return run_crew_tool_with_messages

    click.secho(
        "\nEntering an interactive chat loop with function-calling.\n"
        "Type 'exit' or Ctrl+C to quit.\n",
        fg="cyan",
    )

    # Main chat loop
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
    """Main chat loop for interacting with the user."""
    while True:
        try:
            user_input = click.prompt("You", type=str)
@@ -93,20 +100,14 @@ def run_chat():
                click.echo("Exiting chat. Goodbye!")
                break

            # Append user message
            messages.append({"role": "user", "content": user_input})

            # Invoke the LLM, passing tools and available_functions
            final_response = chat_llm.call(
                messages=messages,
                tools=[crew_tool_schema],
                available_functions=available_functions,
            )

            # Append assistant's reply
            messages.append({"role": "assistant", "content": final_response})

            # Display assistant's reply
            click.secho(f"\nAssistant: {final_response}\n", fg="green")

        except KeyboardInterrupt:
@@ -165,7 +166,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
    """
    try:
        # Serialize 'messages' to JSON string before adding to kwargs
        kwargs['crew_chat_messages'] = json.dumps(messages)
        kwargs["crew_chat_messages"] = json.dumps(messages)

        # Run the crew with the provided inputs
        crew_output = crew.kickoff(inputs=kwargs)
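
The serialization step above is the whole trick for threading chat history into a crew run: the transcript travels as an ordinary string input. A minimal sketch of the same pattern, with invented message and input values:

import json

# Invented example values; any crew input dict works the same way.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize our conversation."},
]
inputs = {"topic": "AI LLMs"}
inputs["crew_chat_messages"] = json.dumps(messages)

# crew.kickoff(inputs=inputs) now receives the transcript as a plain string;
# json.loads(inputs["crew_chat_messages"]) recovers the original list.
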
@@ -184,7 +185,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
def load_crew_and_name() -> Tuple[Crew, str]:
    """
    Loads the crew by importing the crew class from the user's project.


    Returns:
        Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
    """
@@ -258,9 +259,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput
    crew_description = generate_crew_description_with_ai(crew, chat_llm)

    return ChatInputs(
        crew_name=crew_name,
        crew_description=crew_description,
        inputs=input_fields
        crew_name=crew_name, crew_description=crew_description, inputs=input_fields
    )

@@ -307,18 +306,31 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
    placeholder_pattern = re.compile(r"\{(.+?)\}")

    for task in crew.tasks:
        if f"{{{input_name}}}" in task.description or f"{{{input_name}}}" in task.expected_output:
        if (
            f"{{{input_name}}}" in task.description
            or f"{{{input_name}}}" in task.expected_output
        ):
            # Replace placeholders with input names
            task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description)
            expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output)
            task_description = placeholder_pattern.sub(
                lambda m: m.group(1), task.description
            )
            expected_output = placeholder_pattern.sub(
                lambda m: m.group(1), task.expected_output
            )
            context_texts.append(f"Task Description: {task_description}")
            context_texts.append(f"Expected Output: {expected_output}")
    for agent in crew.agents:
        if f"{{{input_name}}}" in agent.role or f"{{{input_name}}}" in agent.goal or f"{{{input_name}}}" in agent.backstory:
        if (
            f"{{{input_name}}}" in agent.role
            or f"{{{input_name}}}" in agent.goal
            or f"{{{input_name}}}" in agent.backstory
        ):
            # Replace placeholders with input names
            agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role)
            agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal)
            agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory)
            agent_backstory = placeholder_pattern.sub(
                lambda m: m.group(1), agent.backstory
            )
            context_texts.append(f"Agent Role: {agent_role}")
            context_texts.append(f"Agent Goal: {agent_goal}")
            context_texts.append(f"Agent Backstory: {agent_backstory}")
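
The placeholder_pattern substitution above de-braces template variables so the surrounding text reads naturally when fed to the LLM. A quick standalone illustration of that regex, using a made-up task description:

import re

placeholder_pattern = re.compile(r"\{(.+?)\}")

# "{topic}" becomes "topic": the braces are dropped, the name is kept.
text = "Research {topic} and summarize it in {language}"
print(placeholder_pattern.sub(lambda m: m.group(1), text))
# -> "Research topic and summarize it in language"
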
@@ -357,8 +369,12 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:

    for task in crew.tasks:
        # Replace placeholders with input names
        task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description)
        expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output)
        task_description = placeholder_pattern.sub(
            lambda m: m.group(1), task.description
        )
        expected_output = placeholder_pattern.sub(
            lambda m: m.group(1), task.expected_output
        )
        context_texts.append(f"Task Description: {task_description}")
        context_texts.append(f"Expected Output: {expected_output}")
    for agent in crew.agents:

@@ -1,81 +0,0 @@
import json
import subprocess

import click
from packaging import version

from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.llm import LLM


def fetch_chat_llm() -> LLM:
    """
    Fetch the chat LLM by running "uv run fetch_chat_llm" (or your chosen script name),
    parsing its JSON stdout, and returning an LLM instance.

    This expects the script "fetch_chat_llm" to print out JSON that represents the
    LLM parameters (e.g., by calling something like: print(json.dumps(llm.to_dict()))).

    Any error, whether from the subprocess or JSON parsing, will raise a RuntimeError.
    """

    # You may change this command to match whatever's in your pyproject.toml [project.scripts].
    command = ["uv", "run", "fetch_chat_llm"]

    crewai_version = get_crewai_version()
    min_required_version = "0.87.0"  # Adjust as needed

    pyproject_data = read_toml()

    # If old poetry-based setup is detected and version is below min_required_version
    if pyproject_data.get("tool", {}).get("poetry") and (
        version.parse(crewai_version) < version.parse(min_required_version)
    ):
        click.secho(
            f"You are running an older version of crewAI ({crewai_version}) that uses poetry pyproject.toml.\n"
            f"Please run `crewai update` to transition your pyproject.toml to use uv.",
            fg="red",
        )

    # Initialize a reference to your LLM
    llm_instance = None

    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
        stdout_lines = result.stdout.strip().splitlines()

        # Find the line that contains the JSON data
        json_line = next(
            (
                line
                for line in stdout_lines
                if line.startswith("{") and line.endswith("}")
            ),
            None,
        )

        if not json_line:
            raise RuntimeError(
                "No valid JSON output received from `fetch_chat_llm` command."
            )

        try:
            llm_data = json.loads(json_line)
            llm_instance = LLM.from_dict(llm_data)
        except json.JSONDecodeError as e:
            raise RuntimeError(
                f"Unable to parse JSON from `fetch_chat_llm` output: {e}\nOutput: {repr(json_line)}"
            ) from e

    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e
    except Exception as e:
        raise RuntimeError(
            f"An unexpected error occurred while fetching chat LLM: {e}"
        ) from e

    if not llm_instance:
        raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.")

    return llm_instance
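
The core trick in the deleted helper was scanning mixed subprocess stdout for the single JSON line. That pattern is easy to sketch in isolation; here stdout is a hard-coded stand-in for result.stdout:

import json

# Stand-in for subprocess output: tool chatter mixed with one JSON line.
stdout = 'Resolved 12 packages\n{"model": "gpt-4o", "temperature": 0.7}\nDone.'

json_line = next(
    (
        line
        for line in stdout.strip().splitlines()
        if line.startswith("{") and line.endswith("}")
    ),
    None,
)
llm_data = json.loads(json_line) if json_line else {}
print(llm_data.get("model"))  # -> gpt-4o
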
@@ -19,22 +19,11 @@ def run():
    Usage example:
        uv run run_crew -- --topic="New Topic" --some_other_field="Value"
    """
    # Default inputs
    inputs = {
        'topic': 'AI LLMs'
        # Add any other default fields here
    }

    # 1) Gather overrides from sys.argv
    # sys.argv might look like: ['run_crew', '--topic=NewTopic']
    # But be aware that if you're calling "uv run run_crew", sys.argv might have
    # additional items. So we typically skip the first 1 or 2 items to get only overrides.
    overrides = parse_cli_overrides(sys.argv[1:])

    # 2) Merge the overrides into defaults
    inputs.update(overrides)

    # 3) Kick off the crew with final inputs

    try:
        {{crew_name}}().crew().kickoff(inputs=inputs)
    except Exception as e:
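
The (now removed) merge step relied on plain dict.update() semantics: CLI values win whenever a key collides with a default. A tiny demonstration with invented values:

inputs = {"topic": "AI LLMs", "language": "en"}  # defaults
overrides = {"topic": "New Topic"}               # parsed from --key=value flags

inputs.update(overrides)  # override replaces the default for "topic"
print(inputs)  # -> {'topic': 'New Topic', 'language': 'en'}
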
@@ -76,93 +65,3 @@ def test():

    except Exception as e:
        raise Exception(f"An error occurred while testing the crew: {e}")

def fetch_inputs():
    """
    Command that gathers required placeholders/inputs from the Crew, then
    prints them as JSON to stdout so external scripts can parse them easily.
    """
    try:
        crew = {{crew_name}}().crew()
        crew_inputs = crew.fetch_inputs()
        json_string = json.dumps(list(crew_inputs))
        print(json_string)
    except Exception as e:
        raise Exception(f"An error occurred while fetching inputs: {e}")

def fetch_chat_llm():
    """
    Command that fetches the 'chat_llm' property from the Crew,
    instantiates it via create_llm(),
    and prints the resulting LLM as JSON (using LLM.to_dict()) to stdout.
    """
    try:
        crew = {{crew_name}}().crew()
        raw_chat_llm = getattr(crew, "chat_llm", None)

        if not raw_chat_llm:
            # If the crew doesn't have chat_llm, fallback to create_llm(None)
            final_llm = create_llm(None)
        else:
            # raw_chat_llm might be a dict, or an LLM, or something else
            final_llm = create_llm(raw_chat_llm)

        if final_llm:
            # Print the final LLM as JSON, so fetch_chat_llm.py can parse it
            from crewai.llm import LLM  # Import here to avoid circular references

            # Make sure it's an instance of the LLM class:
            if isinstance(final_llm, LLM):
                print(json.dumps(final_llm.to_dict()))
            else:
                # If somehow it's not an LLM, try to interpret as a dict
                # or revert to an empty fallback
                if isinstance(final_llm, dict):
                    print(json.dumps(final_llm))
                else:
                    print(json.dumps({}))
        else:
            print(json.dumps({}))
    except Exception as e:
        raise Exception(f"An error occurred while fetching chat LLM: {e}")

# TODO: Talk to Joao about using LLM calls to analyze the crew
# and generate all of this information automatically
def fetch_chat_inputs():
    """
    Command that fetches the 'chat_inputs' property from the Crew,
    and prints it as JSON to stdout.
    """
    try:
        crew = {{crew_name}}().crew()
        raw_chat_inputs = getattr(crew, "chat_inputs", None)

        if raw_chat_inputs:
            # Convert to dictionary to print JSON
            print(json.dumps(raw_chat_inputs.model_dump()))
        else:
            # If crew.chat_inputs is None or empty, print an empty JSON
            print(json.dumps({}))
    except Exception as e:
        raise Exception(f"An error occurred while fetching chat inputs: {e}")


def parse_cli_overrides(args_list) -> dict:
    """
    Parse arguments in the form of --key=value from a list of CLI arguments.
    Return them as a dict. For example:
        ['--topic=AI LLMs', '--username=John'] => {'topic': 'AI LLMs', 'username': 'John'}
    """
    overrides = {}
    for arg in args_list:
        if arg.startswith("--"):
            # remove the leading --
            trimmed = arg[2:]
            if "=" in trimmed:
                key, val = trimmed.split("=", 1)
                overrides[key] = val
            else:
                # If someone passed something like --topic (no =),
                # either handle differently or ignore
                pass
    return overrides
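
For reference, the removed parse_cli_overrides behaves exactly as its docstring says; a quick check with a hypothetical argument list, assuming the function is in scope:

args = ["--topic=AI LLMs", "--username=John", "--verbose"]  # hypothetical argv tail
print(parse_cli_overrides(args))
# -> {'topic': 'AI LLMs', 'username': 'John'}
# "--verbose" has no "=", so the implementation above silently ignores it.
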
@@ -5,7 +5,6 @@ import sys
import threading
import warnings
from contextlib import contextmanager
from importlib import resources
from typing import Any, Dict, List, Optional, Union, cast

from dotenv import load_dotenv
@@ -179,35 +178,6 @@ class LLM:
            "callbacks": self.callbacks,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "LLM":
        """
        Create an LLM instance from a dict.
        We assume the dict has all relevant keys that match what's in the constructor.
        """
        known_fields = {}
        known_fields["model"] = data.pop("model", None)
        known_fields["timeout"] = data.pop("timeout", None)
        known_fields["temperature"] = data.pop("temperature", None)
        known_fields["top_p"] = data.pop("top_p", None)
        known_fields["n"] = data.pop("n", None)
        known_fields["stop"] = data.pop("stop", None)
        known_fields["max_completion_tokens"] = data.pop("max_completion_tokens", None)
        known_fields["max_tokens"] = data.pop("max_tokens", None)
        known_fields["presence_penalty"] = data.pop("presence_penalty", None)
        known_fields["frequency_penalty"] = data.pop("frequency_penalty", None)
        known_fields["logit_bias"] = data.pop("logit_bias", None)
        known_fields["response_format"] = data.pop("response_format", None)
        known_fields["seed"] = data.pop("seed", None)
        known_fields["logprobs"] = data.pop("logprobs", None)
        known_fields["top_logprobs"] = data.pop("top_logprobs", None)
        known_fields["base_url"] = data.pop("base_url", None)
        known_fields["api_version"] = data.pop("api_version", None)
        known_fields["api_key"] = data.pop("api_key", None)
        known_fields["callbacks"] = data.pop("callbacks", None)

        return cls(**known_fields, **data)

    def call(
        self,
        messages: List[Dict[str, str]],
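
The removed from_dict used a pop-then-splat pattern worth noting: known constructor fields are peeled off first, and whatever remains in data rides along as extra keyword arguments. A generic, self-contained sketch of that pattern (not the crewAI class itself):

class Config:
    def __init__(self, model=None, temperature=None, **extra):
        self.model = model
        self.temperature = temperature
        self.extra = extra  # anything not claimed by a named parameter

    @classmethod
    def from_dict(cls, data: dict) -> "Config":
        # pop() both extracts the known fields and removes them from `data`,
        # so the final **data splat only carries the leftovers.
        known = {key: data.pop(key, None) for key in ("model", "temperature")}
        return cls(**known, **data)


cfg = Config.from_dict({"model": "gpt-4o", "temperature": 0.7, "seed": 42})
print(cfg.model, cfg.extra)  # -> gpt-4o {'seed': 42}
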
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List

from pydantic import BaseModel, Field

@@ -14,10 +14,7 @@ class ChatInputField(BaseModel):
    """

    name: str = Field(..., description="The name of the input field")
    description: str = Field(
        ...,
        description="A short description of the input field",
    )
    description: str = Field(..., description="A short description of the input field")


class ChatInputs(BaseModel):
@@ -36,8 +33,7 @@ class ChatInputs(BaseModel):

    crew_name: str = Field(..., description="The name of the crew")
    crew_description: str = Field(
        ...,
        description="A description of the crew's purpose",
        ..., description="A description of the crew's purpose"
    )
    inputs: List[ChatInputField] = Field(
        default_factory=list, description="A list of input fields for the crew"
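
Both models are fully specified above, so constructing them is mechanical; a small usage sketch with invented values (model_dump is the pydantic v2 serializer already used elsewhere in this diff):

from crewai.types.crew_chat import ChatInputField, ChatInputs

chat_inputs = ChatInputs(
    crew_name="research_crew",  # invented example values
    crew_description="Researches a topic and drafts a report",
    inputs=[ChatInputField(name="topic", description="The subject to research")],
)
print(chat_inputs.model_dump())
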
tests/utilities/test_llm_utils.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import os
from unittest.mock import patch

import pytest
from litellm.exceptions import BadRequestError

from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm


def test_create_llm_with_llm_instance():
    existing_llm = LLM(model="gpt-4o")
    llm = create_llm(llm_value=existing_llm)
    assert llm is existing_llm


def test_create_llm_with_valid_model_string():
    llm = create_llm(llm_value="gpt-4o")
    assert isinstance(llm, LLM)
    assert llm.model == "gpt-4o"


def test_create_llm_with_invalid_model_string():
    with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
        llm = create_llm(llm_value="invalid-model")
        llm.call(messages=[{"role": "user", "content": "Hello, world!"}])


def test_create_llm_with_unknown_object_missing_attributes():
    class UnknownObject:
        pass

    unknown_obj = UnknownObject()
    llm = create_llm(llm_value=unknown_obj)

    # Attempt to call the LLM and expect it to raise an error due to missing attributes
    with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
        llm.call(messages=[{"role": "user", "content": "Hello, world!"}])


def test_create_llm_with_none_uses_default_model():
    with patch.dict(os.environ, {}, clear=True):
        with patch("crewai.cli.constants.DEFAULT_LLM_MODEL", "gpt-4o"):
            llm = create_llm(llm_value=None)
            assert isinstance(llm, LLM)
            assert llm.model == "gpt-4o-mini"


def test_create_llm_with_unknown_object():
    class UnknownObject:
        model_name = "gpt-4o"
        temperature = 0.7
        max_tokens = 1500

    unknown_obj = UnknownObject()
    llm = create_llm(llm_value=unknown_obj)
    assert isinstance(llm, LLM)
    assert llm.model == "gpt-4o"
    assert llm.temperature == 0.7
    assert llm.max_tokens == 1500


def test_create_llm_from_env_with_unaccepted_attributes():
    with patch.dict(
        os.environ,
        {
            "OPENAI_MODEL_NAME": "gpt-3.5-turbo",
            "AWS_ACCESS_KEY_ID": "fake-access-key",
            "AWS_SECRET_ACCESS_KEY": "fake-secret-key",
            "AWS_REGION_NAME": "us-west-2",
        },
    ):
        llm = create_llm(llm_value=None)
        assert isinstance(llm, LLM)
        assert llm.model == "gpt-3.5-turbo"
        assert not hasattr(llm, "AWS_ACCESS_KEY_ID")
        assert not hasattr(llm, "AWS_SECRET_ACCESS_KEY")
        assert not hasattr(llm, "AWS_REGION_NAME")


def test_create_llm_with_partial_attributes():
    class PartialAttributes:
        model_name = "gpt-4o"
        # temperature is missing

    obj = PartialAttributes()
    llm = create_llm(llm_value=obj)
    assert isinstance(llm, LLM)
    assert llm.model == "gpt-4o"
    assert llm.temperature is None  # Should handle missing attributes gracefully


def test_create_llm_with_invalid_type():
    with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
        llm = create_llm(llm_value=42)
        llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
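
Assuming a standard checkout of the repo (which uses uv, per the diff above), the new tests run directly with pytest:

uv run pytest tests/utilities/test_llm_utils.py
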