high level chat working

This commit is contained in:
Brandon Hancock
2024-12-27 11:21:40 -05:00
parent 2bf5b15f1e
commit 2f882d68ad
9 changed files with 86 additions and 61 deletions

View File

@@ -136,10 +136,8 @@ class Agent(BaseAgent):
self._set_knowledge() self._set_knowledge()
self.agent_ops_agent_name = self.role self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm, default_model="gpt-4o-mini") self.llm = create_llm(self.llm)
self.function_calling_llm = create_llm( self.function_calling_llm = create_llm(self.function_calling_llm)
self.function_calling_llm, default_model="gpt-4o-mini"
)
if not self.agent_executor: if not self.agent_executor:
self._setup_agent_executor() self._setup_agent_executor()

View File

@@ -158,6 +158,8 @@ MODELS = {
], ],
} }
DEFAULT_LLM_MODEL = "gpt-4o-mini"
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"

View File

@@ -1,7 +1,3 @@
import json
import subprocess
from typing import cast
import click import click
from crewai.cli.fetch_chat_llm import fetch_chat_llm from crewai.cli.fetch_chat_llm import fetch_chat_llm
@@ -20,20 +16,15 @@ def run_chat():
# 1) Fetch CrewInputs # 1) Fetch CrewInputs
click.secho("Gathering crew inputs via `fetch_crew_inputs()`...", fg="cyan") click.secho("Gathering crew inputs via `fetch_crew_inputs()`...", fg="cyan")
try: try:
crew_inputs: ChatInputs = fetch_crew_inputs() crew_inputs = fetch_crew_inputs()
click.echo(f"CrewInputs: {crew_inputs}")
if not crew_inputs:
click.secho("Error: Failed to fetch crew inputs. Exiting.", fg="red")
return
except Exception as e: except Exception as e:
click.secho(f"Error fetching crew inputs: {e}", fg="red") click.secho(f"Error fetching crew inputs: {e}", fg="red")
return return
# Check for mandatory fields
if not crew_inputs.crew_name:
click.secho("Error: Crew name is missing. Exiting.", fg="red")
return
if not crew_inputs.crew_description:
click.secho("Error: Crew description is missing. Exiting.", fg="red")
return
# 2) Generate a tool schema from the crew inputs # 2) Generate a tool schema from the crew inputs
crew_tool_schema = generate_crew_tool_schema(crew_inputs) crew_tool_schema = generate_crew_tool_schema(crew_inputs)
@@ -85,7 +76,7 @@ def run_chat():
# 6) Main chat loop # 6) Main chat loop
while True: while True:
try: try:
user_input = click.prompt("You: ", type=str) user_input = click.prompt("You", type=str)
if user_input.strip().lower() in ["exit", "quit"]: if user_input.strip().lower() in ["exit", "quit"]:
click.echo("Exiting chat. Goodbye!") click.echo("Exiting chat. Goodbye!")
break break

View File

@@ -42,29 +42,39 @@ def fetch_chat_llm() -> LLM:
llm_instance = None llm_instance = None
try: try:
# 1) Run the subprocess
result = subprocess.run(command, capture_output=True, text=True, check=True) result = subprocess.run(command, capture_output=True, text=True, check=True)
stdout_str = result.stdout.strip() stdout_lines = result.stdout.strip().splitlines()
# 2) Attempt to parse stdout as JSON # Find the line that contains the JSON data
if stdout_str: json_line = next(
try: (
llm_data = json.loads(stdout_str) line
llm_instance = LLM.from_dict(llm_data) for line in stdout_lines
except json.JSONDecodeError as e: if line.startswith("{") and line.endswith("}")
raise RuntimeError( ),
f"Unable to parse JSON from `fetch_chat_llm` output: {e}" None,
) from e )
if not json_line:
raise RuntimeError(
"No valid JSON output received from `fetch_chat_llm` command."
)
try:
llm_data = json.loads(json_line)
llm_instance = LLM.from_dict(llm_data)
except json.JSONDecodeError as e:
raise RuntimeError(
f"Unable to parse JSON from `fetch_chat_llm` output: {e}\nOutput: {repr(json_line)}"
) from e
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
# Subprocess error means the script failed
raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
f"An unexpected error occurred while fetching chat LLM: {e}" f"An unexpected error occurred while fetching chat LLM: {e}"
) from e ) from e
# Verify that we have a valid LLM
if not llm_instance: if not llm_instance:
raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.") raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.")

View File

@@ -1,5 +1,6 @@
import json import json
import subprocess import subprocess
from typing import Optional
import click import click
from packaging import version from packaging import version
@@ -9,7 +10,7 @@ from crewai.cli.version import get_crewai_version
from crewai.types.crew_chat import ChatInputs from crewai.types.crew_chat import ChatInputs
def fetch_crew_inputs() -> ChatInputs: def fetch_crew_inputs() -> Optional[ChatInputs]:
""" """
Fetch the crew's ChatInputs (a structure containing crew_description and input fields) Fetch the crew's ChatInputs (a structure containing crew_description and input fields)
by running "uv run fetch_chat_inputs", which prints JSON representing a ChatInputs object. by running "uv run fetch_chat_inputs", which prints JSON representing a ChatInputs object.
@@ -37,22 +38,37 @@ def fetch_crew_inputs() -> ChatInputs:
try: try:
result = subprocess.run(command, capture_output=True, text=True, check=True) result = subprocess.run(command, capture_output=True, text=True, check=True)
stdout_str = result.stdout.strip() stdout_lines = result.stdout.strip().splitlines()
if not stdout_str: # Find the line that contains the JSON data
return ChatInputs(crew_name=crew_name) json_line = next(
(
line
for line in stdout_lines
if line.startswith("{") and line.endswith("}")
),
None,
)
if not json_line:
click.echo(
"No valid JSON output received from `fetch_chat_inputs` command.",
err=True,
)
return None
try: try:
raw_data = json.loads(stdout_str) raw_data = json.loads(json_line)
chat_inputs = ChatInputs(**raw_data) chat_inputs = ChatInputs(**raw_data)
if crew_name: if crew_name:
chat_inputs.crew_name = crew_name chat_inputs.crew_name = crew_name
return chat_inputs return chat_inputs
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
click.echo( click.echo(
f"Unable to parse JSON from `fetch_chat_inputs` output: {e}", err=True f"Unable to parse JSON from `fetch_chat_inputs` output: {e}\nOutput: {repr(json_line)}",
err=True,
) )
return ChatInputs(crew_name=crew_name) return None
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while fetching chat inputs: {e}", err=True) click.echo(f"An error occurred while fetching chat inputs: {e}", err=True)
@@ -67,4 +83,4 @@ def fetch_crew_inputs() -> ChatInputs:
except Exception as e: except Exception as e:
click.echo(f"An unexpected error occurred: {e}", err=True) click.echo(f"An unexpected error occurred: {e}", err=True)
return ChatInputs(crew_name=crew_name) return None

View File

@@ -14,7 +14,7 @@ run_crew = "{{folder_name}}.main:run"
train = "{{folder_name}}.main:train" train = "{{folder_name}}.main:train"
replay = "{{folder_name}}.main:replay" replay = "{{folder_name}}.main:replay"
test = "{{folder_name}}.main:test" test = "{{folder_name}}.main:test"
fetch_inputs = "{{folder_name}}.main:fetch_inputs" fetch_chat_inputs = "{{folder_name}}.main:fetch_chat_inputs"
fetch_chat_llm = "{{folder_name}}.main:fetch_chat_llm" fetch_chat_llm = "{{folder_name}}.main:fetch_chat_llm"
[build-system] [build-system]

View File

@@ -7,13 +7,17 @@ import warnings
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
# Load environment variables from .env file
import litellm import litellm
from litellm import ModelResponse, get_supported_openai_params from dotenv import load_dotenv
from litellm import get_supported_openai_params
from crewai.utilities.exceptions.context_window_exceeding_exception import ( from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException, LLMContextLengthExceededException,
) )
load_dotenv()
class FilteredStream: class FilteredStream:
def __init__(self, original_stream): def __init__(self, original_stream):
@@ -111,7 +115,6 @@ class LLM:
api_version: Optional[str] = None, api_version: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
callbacks: List[Any] = [], callbacks: List[Any] = [],
**kwargs,
): ):
self.model = model self.model = model
self.timeout = timeout self.timeout = timeout
@@ -133,11 +136,9 @@ class LLM:
self.api_key = api_key self.api_key = api_key
self.callbacks = callbacks self.callbacks = callbacks
self.context_window_size = 0 self.context_window_size = 0
self.kwargs = kwargs
# For safety, we disable passing init params to next calls # For safety, we disable passing init params to next calls
litellm.drop_params = True litellm.drop_params = True
litellm.set_verbose = False
self.set_callbacks(callbacks) self.set_callbacks(callbacks)
self.set_env_callbacks() self.set_env_callbacks()
@@ -166,7 +167,6 @@ class LLM:
"api_version": self.api_version, "api_version": self.api_version,
"api_key": self.api_key, "api_key": self.api_key,
"callbacks": self.callbacks, "callbacks": self.callbacks,
"kwargs": self.kwargs,
} }
@classmethod @classmethod
@@ -248,12 +248,13 @@ class LLM:
"api_key": self.api_key, "api_key": self.api_key,
"stream": False, "stream": False,
"tools": tools, # pass the tool schema "tools": tools, # pass the tool schema
**self.kwargs,
} }
# remove None values # remove None values
params = {k: v for k, v in params.items() if v is not None} params = {k: v for k, v in params.items() if v is not None}
print(f"Params: {params}")
response = litellm.completion(**params) response = litellm.completion(**params)
response_message = response.choices[0].message response_message = response.choices[0].message
text_response = response_message.content or "" text_response = response_message.content or ""
@@ -283,7 +284,7 @@ class LLM:
messages.append( messages.append(
{ {
"tool_call_id": tool_call.id, "tool_call_id": tool_call.id,
"role": "tool", "role": "function",
"name": function_name, "name": function_name,
"content": str(result), "content": str(result),
} }

View File

@@ -13,8 +13,11 @@ class ChatInputField(BaseModel):
} }
""" """
name: str name: str = Field(..., description="The name of the input field")
description: Optional[str] = None description: str = Field(
...,
description="A short description of the input field",
)
class ChatInputs(BaseModel): class ChatInputs(BaseModel):
@@ -31,6 +34,11 @@ class ChatInputs(BaseModel):
} }
""" """
crew_name: Optional[str] = Field(default="Crew") crew_name: str = Field(..., description="The name of the crew")
crew_description: Optional[str] = None crew_description: str = Field(
inputs: List[ChatInputField] = Field(default_factory=list) ...,
description="A description of the crew's purpose",
)
inputs: List[ChatInputField] = Field(
default_factory=list, description="A list of input fields for the crew"
)

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Union
from packaging import version from packaging import version
from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.cli.utils import read_toml from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version from crewai.cli.version import get_crewai_version
from crewai.llm import LLM from crewai.llm import LLM
@@ -11,7 +11,6 @@ from crewai.llm import LLM
def create_llm( def create_llm(
llm_value: Union[str, LLM, Any, None] = None, llm_value: Union[str, LLM, Any, None] = None,
default_model: str = "gpt-4o-mini",
) -> Optional[LLM]: ) -> Optional[LLM]:
""" """
Creates or returns an LLM instance based on the given llm_value. Creates or returns an LLM instance based on the given llm_value.
@@ -45,7 +44,7 @@ def create_llm(
# 3) If llm_value is None, parse environment variables or use default # 3) If llm_value is None, parse environment variables or use default
if llm_value is None: if llm_value is None:
return _llm_via_environment_or_fallback(default_model) return _llm_via_environment_or_fallback()
# 4) Otherwise, attempt to extract relevant attributes from an unknown object (like a config) # 4) Otherwise, attempt to extract relevant attributes from an unknown object (like a config)
# e.g. follow the approach used in agent.py # e.g. follow the approach used in agent.py
@@ -104,17 +103,17 @@ def create_chat_llm(default_model: str = "gpt-4") -> Optional[LLM]:
) )
# After checks, simply call create_llm with None (meaning "use env or fallback"): # After checks, simply call create_llm with None (meaning "use env or fallback"):
return create_llm(None, default_model=default_model) return create_llm(None)
def _llm_via_environment_or_fallback( def _llm_via_environment_or_fallback() -> Optional[LLM]:
default_model: str = "gpt-4o-mini",
) -> Optional[LLM]:
""" """
Helper function: if llm_value is None, we load environment variables or fallback default model. Helper function: if llm_value is None, we load environment variables or fallback default model.
""" """
model_name = ( model_name = (
os.environ.get("OPENAI_MODEL_NAME") or os.environ.get("MODEL") or default_model os.environ.get("OPENAI_MODEL_NAME")
or os.environ.get("MODEL")
or DEFAULT_LLM_MODEL
) )
llm_params = {"model": model_name} llm_params = {"model": model_name}