mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-14 02:28:30 +00:00
high level chat working
This commit is contained in:
@@ -136,10 +136,8 @@ class Agent(BaseAgent):
|
||||
self._set_knowledge()
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
self.llm = create_llm(self.llm, default_model="gpt-4o-mini")
|
||||
self.function_calling_llm = create_llm(
|
||||
self.function_calling_llm, default_model="gpt-4o-mini"
|
||||
)
|
||||
self.llm = create_llm(self.llm)
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
|
||||
if not self.agent_executor:
|
||||
self._setup_agent_executor()
|
||||
|
||||
@@ -158,6 +158,8 @@ MODELS = {
|
||||
],
|
||||
}
|
||||
|
||||
DEFAULT_LLM_MODEL = "gpt-4o-mini"
|
||||
|
||||
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import json
|
||||
import subprocess
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
|
||||
from crewai.cli.fetch_chat_llm import fetch_chat_llm
|
||||
@@ -20,20 +16,15 @@ def run_chat():
|
||||
# 1) Fetch CrewInputs
|
||||
click.secho("Gathering crew inputs via `fetch_crew_inputs()`...", fg="cyan")
|
||||
try:
|
||||
crew_inputs: ChatInputs = fetch_crew_inputs()
|
||||
crew_inputs = fetch_crew_inputs()
|
||||
click.echo(f"CrewInputs: {crew_inputs}")
|
||||
if not crew_inputs:
|
||||
click.secho("Error: Failed to fetch crew inputs. Exiting.", fg="red")
|
||||
return
|
||||
except Exception as e:
|
||||
click.secho(f"Error fetching crew inputs: {e}", fg="red")
|
||||
return
|
||||
|
||||
# Check for mandatory fields
|
||||
if not crew_inputs.crew_name:
|
||||
click.secho("Error: Crew name is missing. Exiting.", fg="red")
|
||||
return
|
||||
|
||||
if not crew_inputs.crew_description:
|
||||
click.secho("Error: Crew description is missing. Exiting.", fg="red")
|
||||
return
|
||||
|
||||
# 2) Generate a tool schema from the crew inputs
|
||||
crew_tool_schema = generate_crew_tool_schema(crew_inputs)
|
||||
|
||||
@@ -85,7 +76,7 @@ def run_chat():
|
||||
# 6) Main chat loop
|
||||
while True:
|
||||
try:
|
||||
user_input = click.prompt("You: ", type=str)
|
||||
user_input = click.prompt("You", type=str)
|
||||
if user_input.strip().lower() in ["exit", "quit"]:
|
||||
click.echo("Exiting chat. Goodbye!")
|
||||
break
|
||||
|
||||
@@ -42,29 +42,39 @@ def fetch_chat_llm() -> LLM:
|
||||
llm_instance = None
|
||||
|
||||
try:
|
||||
# 1) Run the subprocess
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=True)
|
||||
stdout_str = result.stdout.strip()
|
||||
stdout_lines = result.stdout.strip().splitlines()
|
||||
|
||||
# 2) Attempt to parse stdout as JSON
|
||||
if stdout_str:
|
||||
try:
|
||||
llm_data = json.loads(stdout_str)
|
||||
llm_instance = LLM.from_dict(llm_data)
|
||||
except json.JSONDecodeError as e:
|
||||
raise RuntimeError(
|
||||
f"Unable to parse JSON from `fetch_chat_llm` output: {e}"
|
||||
) from e
|
||||
# Find the line that contains the JSON data
|
||||
json_line = next(
|
||||
(
|
||||
line
|
||||
for line in stdout_lines
|
||||
if line.startswith("{") and line.endswith("}")
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not json_line:
|
||||
raise RuntimeError(
|
||||
"No valid JSON output received from `fetch_chat_llm` command."
|
||||
)
|
||||
|
||||
try:
|
||||
llm_data = json.loads(json_line)
|
||||
llm_instance = LLM.from_dict(llm_data)
|
||||
except json.JSONDecodeError as e:
|
||||
raise RuntimeError(
|
||||
f"Unable to parse JSON from `fetch_chat_llm` output: {e}\nOutput: {repr(json_line)}"
|
||||
) from e
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
# Subprocess error means the script failed
|
||||
raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"An unexpected error occurred while fetching chat LLM: {e}"
|
||||
) from e
|
||||
|
||||
# Verify that we have a valid LLM
|
||||
if not llm_instance:
|
||||
raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import subprocess
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from packaging import version
|
||||
@@ -9,7 +10,7 @@ from crewai.cli.version import get_crewai_version
|
||||
from crewai.types.crew_chat import ChatInputs
|
||||
|
||||
|
||||
def fetch_crew_inputs() -> ChatInputs:
|
||||
def fetch_crew_inputs() -> Optional[ChatInputs]:
|
||||
"""
|
||||
Fetch the crew's ChatInputs (a structure containing crew_description and input fields)
|
||||
by running "uv run fetch_chat_inputs", which prints JSON representing a ChatInputs object.
|
||||
@@ -37,22 +38,37 @@ def fetch_crew_inputs() -> ChatInputs:
|
||||
|
||||
try:
|
||||
result = subprocess.run(command, capture_output=True, text=True, check=True)
|
||||
stdout_str = result.stdout.strip()
|
||||
stdout_lines = result.stdout.strip().splitlines()
|
||||
|
||||
if not stdout_str:
|
||||
return ChatInputs(crew_name=crew_name)
|
||||
# Find the line that contains the JSON data
|
||||
json_line = next(
|
||||
(
|
||||
line
|
||||
for line in stdout_lines
|
||||
if line.startswith("{") and line.endswith("}")
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not json_line:
|
||||
click.echo(
|
||||
"No valid JSON output received from `fetch_chat_inputs` command.",
|
||||
err=True,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
raw_data = json.loads(stdout_str)
|
||||
raw_data = json.loads(json_line)
|
||||
chat_inputs = ChatInputs(**raw_data)
|
||||
if crew_name:
|
||||
chat_inputs.crew_name = crew_name
|
||||
return chat_inputs
|
||||
except json.JSONDecodeError as e:
|
||||
click.echo(
|
||||
f"Unable to parse JSON from `fetch_chat_inputs` output: {e}", err=True
|
||||
f"Unable to parse JSON from `fetch_chat_inputs` output: {e}\nOutput: {repr(json_line)}",
|
||||
err=True,
|
||||
)
|
||||
return ChatInputs(crew_name=crew_name)
|
||||
return None
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while fetching chat inputs: {e}", err=True)
|
||||
@@ -67,4 +83,4 @@ def fetch_crew_inputs() -> ChatInputs:
|
||||
except Exception as e:
|
||||
click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
|
||||
return ChatInputs(crew_name=crew_name)
|
||||
return None
|
||||
|
||||
@@ -14,7 +14,7 @@ run_crew = "{{folder_name}}.main:run"
|
||||
train = "{{folder_name}}.main:train"
|
||||
replay = "{{folder_name}}.main:replay"
|
||||
test = "{{folder_name}}.main:test"
|
||||
fetch_inputs = "{{folder_name}}.main:fetch_inputs"
|
||||
fetch_chat_inputs = "{{folder_name}}.main:fetch_chat_inputs"
|
||||
fetch_chat_llm = "{{folder_name}}.main:fetch_chat_llm"
|
||||
|
||||
[build-system]
|
||||
|
||||
@@ -7,13 +7,17 @@ import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
# Load environment variables from .env file
|
||||
import litellm
|
||||
from litellm import ModelResponse, get_supported_openai_params
|
||||
from dotenv import load_dotenv
|
||||
from litellm import get_supported_openai_params
|
||||
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class FilteredStream:
|
||||
def __init__(self, original_stream):
|
||||
@@ -111,7 +115,6 @@ class LLM:
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
@@ -133,11 +136,9 @@ class LLM:
|
||||
self.api_key = api_key
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.kwargs = kwargs
|
||||
|
||||
# For safety, we disable passing init params to next calls
|
||||
litellm.drop_params = True
|
||||
litellm.set_verbose = False
|
||||
|
||||
self.set_callbacks(callbacks)
|
||||
self.set_env_callbacks()
|
||||
@@ -166,7 +167,6 @@ class LLM:
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"callbacks": self.callbacks,
|
||||
"kwargs": self.kwargs,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -248,12 +248,13 @@ class LLM:
|
||||
"api_key": self.api_key,
|
||||
"stream": False,
|
||||
"tools": tools, # pass the tool schema
|
||||
**self.kwargs,
|
||||
}
|
||||
|
||||
# remove None values
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
print(f"Params: {params}")
|
||||
|
||||
response = litellm.completion(**params)
|
||||
response_message = response.choices[0].message
|
||||
text_response = response_message.content or ""
|
||||
@@ -283,7 +284,7 @@ class LLM:
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call.id,
|
||||
"role": "tool",
|
||||
"role": "function",
|
||||
"name": function_name,
|
||||
"content": str(result),
|
||||
}
|
||||
|
||||
@@ -13,8 +13,11 @@ class ChatInputField(BaseModel):
|
||||
}
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
name: str = Field(..., description="The name of the input field")
|
||||
description: str = Field(
|
||||
...,
|
||||
description="A short description of the input field",
|
||||
)
|
||||
|
||||
|
||||
class ChatInputs(BaseModel):
|
||||
@@ -31,6 +34,11 @@ class ChatInputs(BaseModel):
|
||||
}
|
||||
"""
|
||||
|
||||
crew_name: Optional[str] = Field(default="Crew")
|
||||
crew_description: Optional[str] = None
|
||||
inputs: List[ChatInputField] = Field(default_factory=list)
|
||||
crew_name: str = Field(..., description="The name of the crew")
|
||||
crew_description: str = Field(
|
||||
...,
|
||||
description="A description of the crew's purpose",
|
||||
)
|
||||
inputs: List[ChatInputField] = Field(
|
||||
default_factory=list, description="A list of input fields for the crew"
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
|
||||
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.llm import LLM
|
||||
@@ -11,7 +11,6 @@ from crewai.llm import LLM
|
||||
|
||||
def create_llm(
|
||||
llm_value: Union[str, LLM, Any, None] = None,
|
||||
default_model: str = "gpt-4o-mini",
|
||||
) -> Optional[LLM]:
|
||||
"""
|
||||
Creates or returns an LLM instance based on the given llm_value.
|
||||
@@ -45,7 +44,7 @@ def create_llm(
|
||||
|
||||
# 3) If llm_value is None, parse environment variables or use default
|
||||
if llm_value is None:
|
||||
return _llm_via_environment_or_fallback(default_model)
|
||||
return _llm_via_environment_or_fallback()
|
||||
|
||||
# 4) Otherwise, attempt to extract relevant attributes from an unknown object (like a config)
|
||||
# e.g. follow the approach used in agent.py
|
||||
@@ -104,17 +103,17 @@ def create_chat_llm(default_model: str = "gpt-4") -> Optional[LLM]:
|
||||
)
|
||||
|
||||
# After checks, simply call create_llm with None (meaning "use env or fallback"):
|
||||
return create_llm(None, default_model=default_model)
|
||||
return create_llm(None)
|
||||
|
||||
|
||||
def _llm_via_environment_or_fallback(
|
||||
default_model: str = "gpt-4o-mini",
|
||||
) -> Optional[LLM]:
|
||||
def _llm_via_environment_or_fallback() -> Optional[LLM]:
|
||||
"""
|
||||
Helper function: if llm_value is None, we load environment variables or fallback default model.
|
||||
"""
|
||||
model_name = (
|
||||
os.environ.get("OPENAI_MODEL_NAME") or os.environ.get("MODEL") or default_model
|
||||
os.environ.get("OPENAI_MODEL_NAME")
|
||||
or os.environ.get("MODEL")
|
||||
or DEFAULT_LLM_MODEL
|
||||
)
|
||||
llm_params = {"model": model_name}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user