Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-01-10 00:28:31 +00:00
Worked on the foundation for new conversational crews. Next up: the chat flow itself.
src/crewai/agent.py
@@ -20,6 +20,7 @@ from crewai.tools.agent_tools.agent_tools import AgentTools
 from crewai.utilities import Converter, Prompts
 from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
 from crewai.utilities.converter import generate_model_description
+from crewai.utilities.llm_utils import create_llm
 from crewai.utilities.token_counter_callback import TokenCalcHandler
 from crewai.utilities.training_handler import CrewTrainingHandler
@@ -134,89 +135,11 @@ class Agent(BaseAgent):
     def post_init_setup(self):
         self._set_knowledge()
         self.agent_ops_agent_name = self.role
-        unaccepted_attributes = [
-            "AWS_ACCESS_KEY_ID",
-            "AWS_SECRET_ACCESS_KEY",
-            "AWS_REGION_NAME",
-        ]
-
-        # Handle different cases for self.llm
-        if isinstance(self.llm, str):
-            # If it's a string, create an LLM instance
-            self.llm = LLM(model=self.llm)
-        elif isinstance(self.llm, LLM):
-            # If it's already an LLM instance, keep it as is
-            pass
-        elif self.llm is None:
-            # Determine the model name from environment variables or use default
-            model_name = (
-                os.environ.get("OPENAI_MODEL_NAME")
-                or os.environ.get("MODEL")
-                or "gpt-4o-mini"
-            )
-            llm_params = {"model": model_name}
-
-            api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get(
-                "OPENAI_BASE_URL"
-            )
-            if api_base:
-                llm_params["base_url"] = api_base
-
-            set_provider = model_name.split("/")[0] if "/" in model_name else "openai"
-
-            # Iterate over all environment variables to find matching API keys or use defaults
-            for provider, env_vars in ENV_VARS.items():
-                if provider == set_provider:
-                    for env_var in env_vars:
-                        # Check if the environment variable is set
-                        key_name = env_var.get("key_name")
-                        if key_name and key_name not in unaccepted_attributes:
-                            env_value = os.environ.get(key_name)
-                            if env_value:
-                                key_name = key_name.lower()
-                                for pattern in LITELLM_PARAMS:
-                                    if pattern in key_name:
-                                        key_name = pattern
-                                        break
-                                llm_params[key_name] = env_value
-                        # Check for default values if the environment variable is not set
-                        elif env_var.get("default", False):
-                            for key, value in env_var.items():
-                                if key not in ["prompt", "key_name", "default"]:
-                                    # Only add default if the key is already set in os.environ
-                                    if key in os.environ:
-                                        llm_params[key] = value
-
-            self.llm = LLM(**llm_params)
-        else:
-            # For any other type, attempt to extract relevant attributes
-            llm_params = {
-                "model": getattr(self.llm, "model_name", None)
-                or getattr(self.llm, "deployment_name", None)
-                or str(self.llm),
-                "temperature": getattr(self.llm, "temperature", None),
-                "max_tokens": getattr(self.llm, "max_tokens", None),
-                "logprobs": getattr(self.llm, "logprobs", None),
-                "timeout": getattr(self.llm, "timeout", None),
-                "max_retries": getattr(self.llm, "max_retries", None),
-                "api_key": getattr(self.llm, "api_key", None),
-                "base_url": getattr(self.llm, "base_url", None),
-                "organization": getattr(self.llm, "organization", None),
-            }
-            # Remove None values to avoid passing unnecessary parameters
-            llm_params = {k: v for k, v in llm_params.items() if v is not None}
-            self.llm = LLM(**llm_params)
-
-        # Similar handling for function_calling_llm
-        if self.function_calling_llm:
-            if isinstance(self.function_calling_llm, str):
-                self.function_calling_llm = LLM(model=self.function_calling_llm)
-            elif not isinstance(self.function_calling_llm, LLM):
-                self.function_calling_llm = LLM(
-                    model=getattr(self.function_calling_llm, "model_name", None)
-                    or getattr(self.function_calling_llm, "deployment_name", None)
-                    or str(self.function_calling_llm)
-                )
+        self.llm = create_llm(self.llm, default_model="gpt-4o-mini")
+        self.function_calling_llm = create_llm(
+            self.function_calling_llm, default_model="gpt-4o-mini"
+        )
 
         if not self.agent_executor:
             self._setup_agent_executor()
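Net effect of this hunk: the Agent no longer resolves models inline; both llm and function_calling_llm funnel through create_llm (added in src/crewai/utilities/llm_utils.py below). A minimal sketch of the resulting behavior, with an illustrative model name:

from crewai import Agent

# A plain string is upgraded to an LLM by post_init_setup via create_llm();
# llm=None would fall back to OPENAI_MODEL_NAME / MODEL or "gpt-4o-mini".
agent = Agent(
    role="Researcher",
    goal="Summarize recent work on {topic}.",
    backstory="A meticulous analyst.",
    llm="gpt-4o",  # illustrative model name
)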
src/crewai/cli/cli.py
@@ -1,11 +1,16 @@
 import os
 from importlib.metadata import version as get_version
-from typing import Optional
+from typing import Optional, Tuple
 
 import click
 
+from crewai import (
+    Crew,  # We'll assume a direct import of the Crew class or import from .somewhere
+)
 from crewai.cli.add_crew_to_flow import add_crew_to_flow
 from crewai.cli.create_crew import create_crew
 from crewai.cli.create_flow import create_flow
+from crewai.cli.fetch_chat_llm import fetch_chat_llm
 from crewai.memory.storage.kickoff_task_outputs_storage import (
     KickoffTaskOutputsSQLiteStorage,
 )
@@ -13,6 +18,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
 from .authentication.main import AuthenticationCommand
 from .deploy.main import DeployCommand
 from .evaluate_crew import evaluate_crew
+from .fetch_crew_inputs import fetch_crew_inputs
 from .install_crew import install_crew
 from .kickoff_flow import kickoff_flow
 from .plot_flow import plot_flow
@@ -342,5 +348,127 @@ def flow_add_crew(crew_name):
     add_crew_to_flow(crew_name)
 
 
+@crewai.command()
+def chat():
+    """
+    Start a conversation with the Crew, collecting user-supplied inputs
+    only if needed. This is a demo of a 'chat' flow.
+    """
+    click.secho("Welcome to CrewAI Chat!", fg="green")
+
+    # --------------------------------------------------------------------------
+    # 1) Attempt to fetch Crew inputs
+    # --------------------------------------------------------------------------
+    click.secho("Gathering crew inputs via `fetch_crew_inputs()`...", fg="cyan")
+    try:
+        crew_inputs = fetch_crew_inputs()
+    except Exception as e:
+        # If an error occurs, print it and halt.
+        click.secho(f"Error fetching crew inputs: {e}", fg="red")
+        return
+
+    # An empty crew_inputs set is fine; we proceed anyway.
+    click.secho(
+        f"Found placeholders (possibly empty): {sorted(crew_inputs)}", fg="yellow"
+    )
+
+    # --------------------------------------------------------------------------
+    # 2) Retrieve the Chat LLM
+    # --------------------------------------------------------------------------
+    click.secho("Fetching the Chat LLM...", fg="cyan")
+    try:
+        chat_llm = fetch_chat_llm()
+    except Exception as e:
+        click.secho(f"Failed to retrieve Chat LLM: {e}", fg="red")
+        return
+
+    if not chat_llm:
+        click.secho("No valid Chat LLM returned. Exiting.", fg="red")
+        return
+
+    # --------------------------------------------------------------------------
+    # 3) Simple chat loop (demo)
+    # --------------------------------------------------------------------------
+    click.secho(
+        "\nEntering interactive chat loop. Type 'exit' or Ctrl+C to quit.\n", fg="cyan"
+    )
+
+    while True:
+        try:
+            user_input = click.prompt("You", type=str)
+            if user_input.strip().lower() in ["exit", "quit"]:
+                click.echo("Exiting chat. Goodbye!")
+                break
+
+            # For demonstration, call the LLM directly on the user input.
+            # LLM.call() expects a list of chat messages (see llm.py below).
+            response = chat_llm.call(
+                messages=[{"role": "user", "content": user_input}]
+            )
+            click.secho(f"\nAI: {response}\n", fg="green")
+
+        except (KeyboardInterrupt, EOFError):
+            click.echo("\nExiting chat. Goodbye!")
+            break
+        except Exception as e:
+            click.secho(f"Error occurred while generating chat response: {e}", fg="red")
+            break
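A hypothetical session with the new command (the placeholder list depends on the project; output abridged, messages taken from the secho calls above):

$ crewai chat
Welcome to CrewAI Chat!
Gathering crew inputs via `fetch_crew_inputs()`...
Found placeholders (possibly empty): ['topic']
Fetching the Chat LLM...

Entering interactive chat loop. Type 'exit' or Ctrl+C to quit.

You: exit
Exiting chat. Goodbye!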
+
+
+def load_crew_and_find_inputs(file_path: str) -> Tuple[Optional[Crew], set]:
+    """
+    Attempt to load a Crew from the provided file path or default location.
+    Then gather placeholders from tasks. Returns (crew, set_of_placeholders).
+    """
+    crew = None
+    placeholders_found = set()
+
+    # 1) If file_path is not provided, attempt to detect the default crew config.
+    if not file_path:
+        # This is naive detection logic.
+        # A real implementation might search typical locations like ./
+        # or src/<project_name>/config/ for a crew configuration.
+        default_candidate = "crew.yaml"
+        if os.path.exists(default_candidate):
+            file_path = default_candidate
+
+    # 2) Try to load the crew from file if file_path exists
+    if file_path and os.path.isfile(file_path):
+        # Pseudocode for loading a crew from file; the details depend on how
+        # the user's config is stored.
+        try:
+            # For demonstration, we would do something like:
+            # with open(file_path, "r") as f:
+            #     content = f.read()
+            # crew_data = parse_yaml_crew(content)
+            # crew = Crew(**crew_data)
+            # Placeholder logic below:
+            crew = Crew(name="ExampleCrew")
+        except Exception as e:
+            click.secho(f"Error loading Crew from {file_path}: {e}", fg="red")
+            raise e
+
+    if crew:
+        # 3) Inspect crew tasks for placeholders.
+        # For each Task, we gather placeholders used in description/expected_output.
+        for task in crew.tasks:
+            placeholders_in_desc = extract_placeholders(task.description)
+            placeholders_in_out = extract_placeholders(task.expected_output)
+            placeholders_found.update(placeholders_in_desc)
+            placeholders_found.update(placeholders_in_out)
+
+    return crew, placeholders_found
+
+
+def extract_placeholders(text: str) -> set:
+    """
+    Given a string, find all placeholders of the form {something} that might be
+    used for input interpolation. This is a naive example; real logic might do
+    more careful parsing to avoid curly braces used in JSON.
+    """
+    import re
+
+    if not text:
+        return set()
+    pattern = r"\{([a-zA-Z0-9_]+)\}"
+    matches = re.findall(pattern, text)
+    return set(matches)
+
+
 if __name__ == "__main__":
     crewai()
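The placeholder regex above is easy to spot-check; these values follow directly from the pattern r"\{([a-zA-Z0-9_]+)\}":

extract_placeholders("Analyze {topic} for {user_name}.")  # {'topic', 'user_name'}
extract_placeholders('{"json": "keys"} do not match')     # set() (quotes and colons break the match)
extract_placeholders(None)                                # set()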
src/crewai/cli/fetch_chat_llm.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+import json
+import subprocess
+
+import click
+from packaging import version
+
+from crewai.cli.utils import read_toml
+from crewai.cli.version import get_crewai_version
+from crewai.llm import LLM
+
+
+def fetch_chat_llm() -> LLM:
+    """
+    Fetch the chat LLM by running "uv run fetch_chat_llm" (or your chosen script name),
+    parsing its JSON stdout, and returning an LLM instance.
+
+    This expects the script "fetch_chat_llm" to print out JSON that represents the
+    LLM parameters (e.g., by calling something like: print(json.dumps(llm.to_dict()))).
+
+    Any error, whether from the subprocess or JSON parsing, will raise a RuntimeError.
+    """
+
+    # You may change this command to match whatever's in your pyproject.toml [project.scripts].
+    command = ["uv", "run", "fetch_chat_llm"]
+
+    crewai_version = get_crewai_version()
+    min_required_version = "0.87.0"  # Adjust as needed
+
+    pyproject_data = read_toml()
+
+    # If an old poetry-based setup is detected and the version is below min_required_version
+    if pyproject_data.get("tool", {}).get("poetry") and (
+        version.parse(crewai_version) < version.parse(min_required_version)
+    ):
+        click.secho(
+            f"You are running an older version of crewAI ({crewai_version}) that uses poetry pyproject.toml.\n"
+            f"Please run `crewai update` to transition your pyproject.toml to use uv.",
+            fg="red",
+        )
+
+    # Initialize a reference to the LLM
+    llm_instance = None
+
+    try:
+        # 1) Run the subprocess
+        result = subprocess.run(command, capture_output=True, text=True, check=True)
+        stdout_str = result.stdout.strip()
+
+        # 2) Attempt to parse stdout as JSON
+        if stdout_str:
+            try:
+                llm_data = json.loads(stdout_str)
+                llm_instance = LLM.from_dict(llm_data)
+            except json.JSONDecodeError as e:
+                raise RuntimeError(
+                    f"Unable to parse JSON from `fetch_chat_llm` output: {e}"
+                ) from e
+
+    except subprocess.CalledProcessError as e:
+        # A subprocess error means the script itself failed
+        raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e
+    except Exception as e:
+        raise RuntimeError(
+            f"An unexpected error occurred while fetching chat LLM: {e}"
+        ) from e
+
+    # Verify that we have a valid LLM
+    if not llm_instance:
+        raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.")
+
+    return llm_instance
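The JSON contract between this helper and the generated project's script is implicit; a hedged sketch of a payload that would round-trip (keys mirror LLM.to_dict() later in this commit, values illustrative):

import json
from crewai.llm import LLM

payload = '{"model": "gpt-4o-mini", "temperature": 0.2}'
llm = LLM.from_dict(json.loads(payload))  # unknown keys pass through as kwargs
print(llm.model)  # gpt-4o-mini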
src/crewai/cli/fetch_crew_inputs.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import json
+import subprocess
+
+import click
+from packaging import version
+
+from crewai.cli.utils import read_toml
+from crewai.cli.version import get_crewai_version
+
+
+def fetch_crew_inputs() -> set[str]:
+    """
+    Fetch placeholders/inputs for the crew by running 'uv run fetch_inputs'.
+    This captures stdout (which is now expected to be JSON),
+    parses it into a Python list/set, and returns it.
+    """
+    command = ["uv", "run", "fetch_inputs"]
+    placeholders = set()
+
+    crewai_version = get_crewai_version()
+    min_required_version = "0.87.0"  # TODO: Update to latest version when cut
+
+    pyproject_data = read_toml()
+
+    # Check for old poetry-based setups
+    if pyproject_data.get("tool", {}).get("poetry") and (
+        version.parse(crewai_version) < version.parse(min_required_version)
+    ):
+        click.secho(
+            f"You are running an older version of crewAI ({crewai_version}) that uses poetry pyproject.toml.\n"
+            f"Please run `crewai update` to update your pyproject.toml to use uv.",
+            fg="red",
+        )
+
+    try:
+        result = subprocess.run(command, capture_output=True, text=True, check=True)
+        # The entire stdout should now be a JSON array of placeholders (e.g. ["topic","username",...])
+        stdout_str = result.stdout.strip()
+        if stdout_str:
+            try:
+                placeholders_list = json.loads(stdout_str)
+                if isinstance(placeholders_list, list):
+                    placeholders = set(placeholders_list)
+            except json.JSONDecodeError:
+                click.echo("Unable to parse JSON from `fetch_inputs` output.", err=True)
+    except subprocess.CalledProcessError as e:
+        click.echo(f"An error occurred while fetching inputs: {e}", err=True)
+        click.echo(e.output, err=True, nl=True)
+
+        if pyproject_data.get("tool", {}).get("poetry"):
+            click.secho(
+                "It's possible that you are using an old version of crewAI that uses poetry.\n"
+                "Please run `crewai update` to update your pyproject.toml to use uv.",
+                fg="yellow",
+            )
+
+    except Exception as e:
+        click.echo(f"An unexpected error occurred: {e}", err=True)
+
+    return placeholders
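The wire format here is just a JSON array on stdout; a minimal illustration of the parsing step above:

import json

stdout_str = '["topic", "username"]'  # hypothetical `uv run fetch_inputs` output
placeholders = set(json.loads(stdout_str))  # {'topic', 'username'}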
src/crewai/cli/templates/crew/main.py
@@ -1,8 +1,10 @@
 #!/usr/bin/env python
 import sys
+import json
 import warnings
 
 from {{folder_name}}.crew import {{crew_name}}
+from crewai.utilities.llm_utils import create_llm
 
 warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
@@ -55,4 +57,53 @@ def test():
         {{crew_name}}().crew().test(n_iterations=int(sys.argv[1]), openai_model_name=sys.argv[2], inputs=inputs)
 
     except Exception as e:
-        raise Exception(f"An error occurred while replaying the crew: {e}")
+        raise Exception(f"An error occurred while testing the crew: {e}")
+
+
+def fetch_inputs():
+    """
+    Command that gathers required placeholders/inputs from the Crew, then
+    prints them as JSON to stdout so external scripts can parse them easily.
+    """
+    try:
+        crew = {{crew_name}}().crew()
+        crew_inputs = crew.fetch_inputs()
+        json_string = json.dumps(list(crew_inputs))
+        print(json_string)
+    except Exception as e:
+        raise Exception(f"An error occurred while fetching inputs: {e}")
+
+
+def fetch_chat_llm():
+    """
+    Command that fetches the 'chat_llm' property from the Crew,
+    instantiates it via create_llm(),
+    and prints the resulting LLM as JSON (using LLM.to_dict()) to stdout.
+    """
+    try:
+        crew = {{crew_name}}().crew()
+        raw_chat_llm = getattr(crew, "chat_llm", None)
+
+        if not raw_chat_llm:
+            # If the crew doesn't have chat_llm, fall back to create_llm(None)
+            final_llm = create_llm(None)
+        else:
+            # raw_chat_llm might be a dict, an LLM, or something else
+            final_llm = create_llm(raw_chat_llm)
+
+        if final_llm:
+            # Print the final LLM as JSON, so fetch_chat_llm.py can parse it
+            from crewai.llm import LLM  # Import here to avoid circular references
+
+            # Make sure it's an instance of the LLM class:
+            if isinstance(final_llm, LLM):
+                print(json.dumps(final_llm.to_dict()))
+            else:
+                # If somehow it's not an LLM, try to interpret it as a dict,
+                # otherwise fall back to an empty JSON object
+                if isinstance(final_llm, dict):
+                    print(json.dumps(final_llm))
+                else:
+                    print(json.dumps({}))
+        else:
+            print(json.dumps({}))
+    except Exception as e:
+        raise Exception(f"An error occurred while fetching chat LLM: {e}")
src/crewai/cli/templates/crew/pyproject.toml
@@ -14,6 +14,8 @@ run_crew = "{{folder_name}}.main:run"
 train = "{{folder_name}}.main:train"
 replay = "{{folder_name}}.main:replay"
 test = "{{folder_name}}.main:test"
+fetch_inputs = "{{folder_name}}.main:fetch_inputs"
+fetch_chat_llm = "{{folder_name}}.main:fetch_chat_llm"
 
 [build-system]
 requires = ["hatchling"]
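With these script entries in place, the CLI helpers above can shell out to a generated project; a hypothetical session (output depends on the crew's placeholders and chat_llm):

$ uv run fetch_inputs
["topic"]
$ uv run fetch_chat_llm
{"model": "gpt-4o-mini", "timeout": null, "temperature": null, ...}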
src/crewai/crew.py
@@ -1,10 +1,11 @@
 import asyncio
+import json
 import re
 import uuid
 import warnings
 from concurrent.futures import Future
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 from pydantic import (
     UUID4,
@@ -202,6 +203,10 @@ class Crew(BaseModel):
         default=None,
         description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
     )
+    chat_llm: Optional[Any] = Field(
+        default=None,
+        description="LLM used to handle chatting with the crew.",
+    )
     _knowledge: Optional[Knowledge] = PrivateAttr(
         default=None,
     )
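Since the field is typed Optional[Any], anything create_llm() understands should be assignable; a hedged sketch:

crew = Crew(
    agents=[agent],
    tasks=[task],
    chat_llm="gpt-4o-mini",  # a model string; an LLM instance should also work
)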
@@ -957,6 +962,31 @@ class Crew(BaseModel):
             return self._knowledge.query(query)
         return None
 
+    def fetch_inputs(self) -> Set[str]:
+        """
+        Gathers placeholders (e.g., {something}) referenced in tasks or agents.
+        Scans each task's 'description' + 'expected_output', and each agent's
+        'role', 'goal', and 'backstory'.
+
+        Returns a set of all discovered placeholder names.
+        """
+        placeholder_pattern = re.compile(r"\{(.+?)\}")
+        required_inputs: Set[str] = set()
+
+        # Scan tasks for inputs
+        for task in self.tasks:
+            # description and expected_output might contain e.g. {topic}, {user_name}, etc.
+            text = f"{task.description or ''} {task.expected_output or ''}"
+            required_inputs.update(placeholder_pattern.findall(text))
+
+        # Scan agents for inputs
+        for agent in self.agents:
+            # role, goal, backstory might have placeholders like {role_detail}, etc.
+            text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
+            required_inputs.update(placeholder_pattern.findall(text))
+
+        return required_inputs
+
     def copy(self):
         """Create a deep copy of the Crew."""
src/crewai/llm.py
@@ -140,6 +140,65 @@ class LLM:
         self.set_callbacks(callbacks)
         self.set_env_callbacks()
 
+    def to_dict(self) -> dict:
+        """
+        Return a dict of all relevant parameters for serialization.
+        """
+        return {
+            "model": self.model,
+            "timeout": self.timeout,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "n": self.n,
+            "stop": self.stop,
+            "max_completion_tokens": self.max_completion_tokens,
+            "max_tokens": self.max_tokens,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+            "logit_bias": self.logit_bias,
+            "response_format": self.response_format,
+            "seed": self.seed,
+            "logprobs": self.logprobs,
+            "top_logprobs": self.top_logprobs,
+            "base_url": self.base_url,
+            "api_version": self.api_version,
+            "api_key": self.api_key,
+            "callbacks": self.callbacks,
+            "kwargs": self.kwargs,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "LLM":
+        """
+        Create an LLM instance from a dict.
+        We assume the dict has all relevant keys that match what's in the constructor.
+        """
+        # Pop off the fields we know, then pass the rest in as **kwargs
+        # so that any leftover keys still reach the LLM constructor.
+        known_fields = {}
+        known_fields["model"] = data.pop("model", None)
+        known_fields["timeout"] = data.pop("timeout", None)
+        known_fields["temperature"] = data.pop("temperature", None)
+        known_fields["top_p"] = data.pop("top_p", None)
+        known_fields["n"] = data.pop("n", None)
+        known_fields["stop"] = data.pop("stop", None)
+        known_fields["max_completion_tokens"] = data.pop("max_completion_tokens", None)
+        known_fields["max_tokens"] = data.pop("max_tokens", None)
+        known_fields["presence_penalty"] = data.pop("presence_penalty", None)
+        known_fields["frequency_penalty"] = data.pop("frequency_penalty", None)
+        known_fields["logit_bias"] = data.pop("logit_bias", None)
+        known_fields["response_format"] = data.pop("response_format", None)
+        known_fields["seed"] = data.pop("seed", None)
+        known_fields["logprobs"] = data.pop("logprobs", None)
+        known_fields["top_logprobs"] = data.pop("top_logprobs", None)
+        known_fields["base_url"] = data.pop("base_url", None)
+        known_fields["api_version"] = data.pop("api_version", None)
+        known_fields["api_key"] = data.pop("api_key", None)
+        known_fields["callbacks"] = data.pop("callbacks", None)
+
+        # leftover keys go into kwargs:
+        return cls(**known_fields, **data)
+
     def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
         with suppress_warnings():
             if callbacks and len(callbacks) > 0:
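A sketch of the intended round-trip. One caveat: to_dict() includes callbacks and response_format, which may not survive json.dumps() unchanged, so serializers likely need to drop or encode those fields:

llm = LLM(model="gpt-4o-mini", temperature=0.2, seed=42)
clone = LLM.from_dict(llm.to_dict())
assert (clone.model, clone.temperature, clone.seed) == ("gpt-4o-mini", 0.2, 42)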
src/crewai/utilities/llm_utils.py (new file, 166 lines)
@@ -0,0 +1,166 @@
+import os
+from typing import Any, Dict, Optional, Union
+
+from packaging import version
+
+from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS
+from crewai.cli.utils import read_toml
+from crewai.cli.version import get_crewai_version
+from crewai.llm import LLM
+
+
+def create_llm(
+    llm_value: Union[str, LLM, Any, None] = None,
+    default_model: str = "gpt-4o-mini",
+) -> Optional[LLM]:
+    """
+    Creates or returns an LLM instance based on the given llm_value.
+
+    Args:
+        llm_value (str | LLM | Any | None):
+            - str: The model name (e.g., "gpt-4").
+            - LLM: Already instantiated LLM, returned as-is.
+            - Any: Attempt to extract known attributes like model_name, temperature, etc.
+            - None: Use environment-based or fallback default model.
+        default_model (str): The fallback model name to use if llm_value is None
+            and no environment variable is set.
+
+    Returns:
+        An LLM instance if successful, or None if something fails.
+    """
+
+    # 1) If llm_value is already an LLM object, return it directly
+    if isinstance(llm_value, LLM):
+        return llm_value
+
+    # 2) If llm_value is a string (model name)
+    if isinstance(llm_value, str):
+        try:
+            created_llm = LLM(model=llm_value)
+            print(f"LLM created with model='{llm_value}'")
+            return created_llm
+        except Exception as e:
+            print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
+            return None
+
+    # 3) If llm_value is None, parse environment variables or use the default
+    if llm_value is None:
+        return _llm_via_environment_or_fallback(default_model)
+
+    # 4) Otherwise, attempt to extract relevant attributes from an unknown object
+    #    (like a config), following the approach previously used in agent.py
+    try:
+        llm_params = {
+            "model": (
+                getattr(llm_value, "model_name", None)
+                or getattr(llm_value, "deployment_name", None)
+                or str(llm_value)
+            ),
+            "temperature": getattr(llm_value, "temperature", None),
+            "max_tokens": getattr(llm_value, "max_tokens", None),
+            "logprobs": getattr(llm_value, "logprobs", None),
+            "timeout": getattr(llm_value, "timeout", None),
+            "max_retries": getattr(llm_value, "max_retries", None),
+            "api_key": getattr(llm_value, "api_key", None),
+            "base_url": getattr(llm_value, "base_url", None),
+            "organization": getattr(llm_value, "organization", None),
+        }
+        # Remove None values
+        llm_params = {k: v for k, v in llm_params.items() if v is not None}
+        created_llm = LLM(**llm_params)
+        print(
+            "LLM created with extracted parameters; "
+            f"model='{llm_params.get('model', 'UNKNOWN')}'"
+        )
+        return created_llm
+    except Exception as e:
+        print(f"Error instantiating LLM from unknown object type: {e}")
+        return None
+
+
+def create_chat_llm(default_model: str = "gpt-4") -> Optional[LLM]:
+    """
+    Creates a Chat LLM with additional checks, such as verifying the crewAI version
+    or reading from pyproject.toml. Then calls `create_llm(None, default_model)`.
+
+    Args:
+        default_model (str): Fallback model if not set in the environment.
+
+    Returns:
+        An instance of LLM or None if instantiation fails.
+    """
+    print("[create_chat_llm] Checking environment and version info...")
+
+    crewai_version = get_crewai_version()
+    min_required_version = "0.87.0"  # Update to latest if needed
+
+    pyproject_data = read_toml()
+    if pyproject_data.get("tool", {}).get("poetry") and (
+        version.parse(crewai_version) < version.parse(min_required_version)
+    ):
+        print(
+            f"You are running an older version of crewAI ({crewai_version}) that uses poetry.\n"
+            "Please run `crewai update` to switch to uv-based builds."
+        )
+
+    # After the checks, simply call create_llm with None (meaning "use env or fallback"):
+    return create_llm(None, default_model=default_model)
+
+
+def _llm_via_environment_or_fallback(
+    default_model: str = "gpt-4o-mini",
+) -> Optional[LLM]:
+    """
+    Helper: when llm_value is None, build the LLM from environment variables
+    or fall back to the default model.
+    """
+    model_name = (
+        os.environ.get("OPENAI_MODEL_NAME") or os.environ.get("MODEL") or default_model
+    )
+    llm_params = {"model": model_name}
+
+    # Optional base URL from env
+    api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL")
+    if api_base:
+        llm_params["base_url"] = api_base
+
+    UNACCEPTED_ATTRIBUTES = [
+        "AWS_ACCESS_KEY_ID",
+        "AWS_SECRET_ACCESS_KEY",
+        "AWS_REGION_NAME",
+    ]
+    set_provider = model_name.split("/")[0] if "/" in model_name else "openai"
+
+    if set_provider in ENV_VARS:
+        for env_var in ENV_VARS[set_provider]:
+            key_name = env_var.get("key_name")
+            if key_name and key_name not in UNACCEPTED_ATTRIBUTES:
+                env_value = os.environ.get(key_name)
+                if env_value:
+                    # Map environment variable names to recognized LITELLM_PARAMS if any
+                    param_key = _normalize_key_name(key_name.lower())
+                    llm_params[param_key] = env_value
+            elif env_var.get("default", False):
+                for key, value in env_var.items():
+                    if key not in ["prompt", "key_name", "default"]:
+                        if key in os.environ:
+                            llm_params[key] = value
+
+    # Try creating the LLM
+    try:
+        new_llm = LLM(**llm_params)
+        print(f"LLM created with model='{model_name}'")
+        return new_llm
+    except Exception as e:
+        print(f"Error instantiating LLM from environment/fallback: {e}")
+        return None
+
+
+def _normalize_key_name(key_name: str) -> str:
+    """
+    Maps environment variable names to recognized litellm parameter keys,
+    using patterns from LITELLM_PARAMS.
+    """
+    for pattern in LITELLM_PARAMS:
+        if pattern in key_name:
+            return pattern
+    return key_name
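The dispatch order in create_llm() can be summarized with a few hedged calls (model names illustrative):

from crewai.utilities.llm_utils import create_llm

llm = create_llm("gpt-4o")   # 2) string -> LLM(model="gpt-4o")
same = create_llm(llm)       # 1) an LLM instance passes through unchanged
env_llm = create_llm(None)   # 3) env vars (OPENAI_MODEL_NAME / MODEL) or the default
# 4) any other object: model_name, temperature, etc. are extracted via getattr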
tests/crew_test.py
@@ -2583,3 +2583,26 @@ def test_hierarchical_verbose_false_manager_agent():
 
     assert crew.manager_agent is not None
     assert not crew.manager_agent.verbose
+
+
+def test_fetch_inputs():
+    agent = Agent(
+        role="{role_detail} Researcher",
+        goal="Research on {topic}.",
+        backstory="Expert in {field}.",
+    )
+
+    task = Task(
+        description="Analyze the data on {topic}.",
+        expected_output="Summary of {topic} analysis.",
+        agent=agent,
+    )
+
+    crew = Crew(agents=[agent], tasks=[task])
+
+    expected_placeholders = {"role_detail", "topic", "field"}
+    actual_placeholders = crew.fetch_inputs()
+
+    assert (
+        actual_placeholders == expected_placeholders
+    ), f"Expected {expected_placeholders}, but got {actual_placeholders}"
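To exercise just the new test (the file path is assumed from where test_hierarchical_verbose_false_manager_agent lives in the repo):

uv run pytest tests/crew_test.py -k test_fetch_inputs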