Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-25 16:18:13 +00:00

Compare commits: fix/embedd...0aae59dc1d (13 commits)
| SHA1 |
|---|
| 0aae59dc1d |
| b97c4cf2c6 |
| 4a794622c7 |
| 104e8bc167 |
| 9e5d4972b9 |
| 3ba15e8bc9 |
| 0284095ff8 |
| bcd838a2ff |
| 5da6d36dd9 |
| 0e7aa192c0 |
| 2f882d68ad |
| 2bf5b15f1e |
| 1c45f730c6 |
@@ -21,6 +21,7 @@ from crewai.tools.base_tool import Tool
 from crewai.utilities import Converter, Prompts
 from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
 from crewai.utilities.converter import generate_model_description
+from crewai.utilities.llm_utils import create_llm
 from crewai.utilities.token_counter_callback import TokenCalcHandler
 from crewai.utilities.training_handler import CrewTrainingHandler
 
@@ -139,89 +140,9 @@ class Agent(BaseAgent):
     def post_init_setup(self):
         self._set_knowledge()
         self.agent_ops_agent_name = self.role
-        unaccepted_attributes = [
-            "AWS_ACCESS_KEY_ID",
-            "AWS_SECRET_ACCESS_KEY",
-            "AWS_REGION_NAME",
-        ]
-
-        # Handle different cases for self.llm
-        if isinstance(self.llm, str):
-            # If it's a string, create an LLM instance
-            self.llm = LLM(model=self.llm)
-        elif isinstance(self.llm, LLM):
-            # If it's already an LLM instance, keep it as is
-            pass
-        elif self.llm is None:
-            # Determine the model name from environment variables or use default
-            model_name = (
-                os.environ.get("OPENAI_MODEL_NAME")
-                or os.environ.get("MODEL")
-                or "gpt-4o-mini"
-            )
-            llm_params = {"model": model_name}
-
-            api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get(
-                "OPENAI_BASE_URL"
-            )
-            if api_base:
-                llm_params["base_url"] = api_base
-
-            set_provider = model_name.split("/")[0] if "/" in model_name else "openai"
-
-            # Iterate over all environment variables to find matching API keys or use defaults
-            for provider, env_vars in ENV_VARS.items():
-                if provider == set_provider:
-                    for env_var in env_vars:
-                        # Check if the environment variable is set
-                        key_name = env_var.get("key_name")
-                        if key_name and key_name not in unaccepted_attributes:
-                            env_value = os.environ.get(key_name)
-                            if env_value:
-                                key_name = key_name.lower()
-                                for pattern in LITELLM_PARAMS:
-                                    if pattern in key_name:
-                                        key_name = pattern
-                                        break
-                                llm_params[key_name] = env_value
-                        # Check for default values if the environment variable is not set
-                        elif env_var.get("default", False):
-                            for key, value in env_var.items():
-                                if key not in ["prompt", "key_name", "default"]:
-                                    # Only add default if the key is already set in os.environ
-                                    if key in os.environ:
-                                        llm_params[key] = value
-
-            self.llm = LLM(**llm_params)
-        else:
-            # For any other type, attempt to extract relevant attributes
-            llm_params = {
-                "model": getattr(self.llm, "model_name", None)
-                or getattr(self.llm, "deployment_name", None)
-                or str(self.llm),
-                "temperature": getattr(self.llm, "temperature", None),
-                "max_tokens": getattr(self.llm, "max_tokens", None),
-                "logprobs": getattr(self.llm, "logprobs", None),
-                "timeout": getattr(self.llm, "timeout", None),
-                "max_retries": getattr(self.llm, "max_retries", None),
-                "api_key": getattr(self.llm, "api_key", None),
-                "base_url": getattr(self.llm, "base_url", None),
-                "organization": getattr(self.llm, "organization", None),
-            }
-            # Remove None values to avoid passing unnecessary parameters
-            llm_params = {k: v for k, v in llm_params.items() if v is not None}
-            self.llm = LLM(**llm_params)
-
-        # Similar handling for function_calling_llm
-        if self.function_calling_llm:
-            if isinstance(self.function_calling_llm, str):
-                self.function_calling_llm = LLM(model=self.function_calling_llm)
-            elif not isinstance(self.function_calling_llm, LLM):
-                self.function_calling_llm = LLM(
-                    model=getattr(self.function_calling_llm, "model_name", None)
-                    or getattr(self.function_calling_llm, "deployment_name", None)
-                    or str(self.function_calling_llm)
-                )
+
+        self.llm = create_llm(self.llm)
+        self.function_calling_llm = create_llm(self.function_calling_llm)
 
         if not self.agent_executor:
             self._setup_agent_executor()
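The roughly eighty removed lines above collapse into `create_llm`, imported from `crewai.utilities.llm_utils` in the first hunk. Its body is not part of this compare view; as a rough, illustrative sketch of the contract the removed branches imply (string → `LLM(model=...)`, existing `LLM` → pass-through, `None` → environment defaults, anything else → attribute extraction), it presumably looks something like:

    # Illustrative sketch only; the real create_llm lives in
    # crewai/utilities/llm_utils.py and is not shown in this diff.
    import os
    from crewai.llm import LLM

    def create_llm(value):
        if isinstance(value, LLM):
            return value  # already an LLM instance: identity
        if isinstance(value, str):
            return LLM(model=value)  # model name given as a string
        if value is None:
            # Fall back to env-configured model, mirroring the removed branch
            model = (
                os.environ.get("OPENAI_MODEL_NAME")
                or os.environ.get("MODEL")
                or "gpt-4o-mini"
            )
            return LLM(model=model)
        # Any other object: salvage common attributes, dropping the Nones
        params = {
            "model": getattr(value, "model_name", None) or str(value),
            "temperature": getattr(value, "temperature", None),
            "api_key": getattr(value, "api_key", None),
            "base_url": getattr(value, "base_url", None),
        }
        return LLM(**{k: v for k, v in params.items() if v is not None})

A quick sanity check of that contract: `create_llm("gpt-4o")` should behave like `LLM(model="gpt-4o")`, and `create_llm(existing_llm)` should be the identity.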
@@ -272,6 +193,8 @@ class Agent(BaseAgent):
 
         task_prompt = task.prompt()
 
+        print("task_prompt:", task_prompt)
+
         # If the task requires output in JSON or Pydantic format,
         # append specific instructions to the task prompt to ensure
         # that the final answer does not include any code block markers
@@ -1,11 +1,14 @@
+import os
 from importlib.metadata import version as get_version
-from typing import Optional
+from typing import Optional, Tuple
 
 import click
 
 from crewai.cli.add_crew_to_flow import add_crew_to_flow
 from crewai.cli.create_crew import create_crew
 from crewai.cli.create_flow import create_flow
+from crewai.cli.crew_chat import run_chat
+from crewai.cli.fetch_chat_llm import fetch_chat_llm
 from crewai.memory.storage.kickoff_task_outputs_storage import (
     KickoffTaskOutputsSQLiteStorage,
 )
@@ -342,5 +345,15 @@ def flow_add_crew(crew_name):
     add_crew_to_flow(crew_name)
 
 
+@crewai.command()
+def chat():
+    """
+    Start a conversation with the Crew, collecting user-supplied inputs,
+    and using the Chat LLM to generate responses.
+    """
+    click.echo("Starting a conversation with the Crew")
+    run_chat()
+
+
 if __name__ == "__main__":
     crewai()
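The new subcommand hooks the chat loop into the existing click group, so it becomes reachable as `crewai chat` from inside a generated crew project (it needs the project's pyproject.toml and a Chat LLM to be resolvable). A hypothetical smoke test, assuming this CLI module is importable as `crewai.cli.cli`:

    # Hypothetical; run from a crew project directory.
    from click.testing import CliRunner
    from crewai.cli.cli import chat  # module path assumed

    result = CliRunner().invoke(chat, input="exit\n")
    print(result.output)  # begins with "Starting a conversation with the Crew"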
@@ -158,6 +158,8 @@ MODELS = {
     ],
 }
 
+DEFAULT_LLM_MODEL = "gpt-4o-mini"
+
 JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
 
 
src/crewai/cli/crew_chat.py (new file, 386 lines)
@@ -0,0 +1,386 @@

import json
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple

import click
import tomli

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.cli.fetch_chat_llm import fetch_chat_llm
from crewai.cli.fetch_crew_inputs import fetch_crew_inputs
from crewai.crew import Crew
from crewai.task import Task
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm


def run_chat():
    """
    Runs an interactive chat loop using the Crew's chat LLM with function calling.
    Incorporates crew_name, crew_description, and input fields to build a tool schema.
    Exits if crew_name or crew_description are missing.
    """
    crew, crew_name = load_crew_and_name()
    click.secho("\nFetching the Chat LLM...", fg="cyan")
    try:
        chat_llm = create_llm(crew.chat_llm)
    except Exception as e:
        click.secho(f"Failed to retrieve Chat LLM: {e}", fg="red")
        return
    if not chat_llm:
        click.secho("No valid Chat LLM returned. Exiting.", fg="red")
        return

    # Generate crew chat inputs automatically
    crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm)
    print("crew_inputs:", crew_chat_inputs)

    # Generate a tool schema from the crew inputs
    crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs)
    print("crew_tool_schema:", crew_tool_schema)

    # Build initial system message
    required_fields_str = (
        ", ".join(
            f"{field.name} (desc: {field.description or 'n/a'})"
            for field in crew_chat_inputs.inputs
        )
        or "(No required fields detected)"
    )

    system_message = (
        "You are a helpful AI assistant for the CrewAI platform. "
        "Your primary purpose is to assist users with the crew's specific tasks. "
        "You can answer general questions, but should guide users back to the crew's purpose afterward. "
        "For example, after answering a general question, remind the user of your main purpose, such as generating a research report, and prompt them to specify a topic or task related to the crew's purpose. "
        "You have a function (tool) you can call by name if you have all required inputs. "
        f"Those required inputs are: {required_fields_str}. "
        "Once you have them, call the function. "
        "Please keep your responses concise and friendly. "
        "If a user asks a question outside the crew's scope, provide a brief answer and remind them of the crew's purpose. "
        "After calling the tool, be prepared to take user feedback and make adjustments as needed. "
        "If you are ever unsure about a user's request or need clarification, ask the user for more information."
        f"\nCrew Name: {crew_chat_inputs.crew_name}"
        f"\nCrew Description: {crew_chat_inputs.crew_description}"
    )

    messages = [
        {"role": "system", "content": system_message},
    ]

    # Create a wrapper function that captures 'crew' and 'messages' from the enclosing scope
    def run_crew_tool_with_messages(**kwargs):
        return run_crew_tool(crew, messages, **kwargs)

    # Prepare available_functions with the wrapper function
    available_functions = {
        crew_chat_inputs.crew_name: run_crew_tool_with_messages,
    }

    click.secho(
        "\nEntering an interactive chat loop with function-calling.\n"
        "Type 'exit' or Ctrl+C to quit.\n",
        fg="cyan",
    )

    # Main chat loop
    while True:
        try:
            user_input = click.prompt("You", type=str)
            if user_input.strip().lower() in ["exit", "quit"]:
                click.echo("Exiting chat. Goodbye!")
                break

            # Append user message
            messages.append({"role": "user", "content": user_input})

            # Invoke the LLM, passing tools and available_functions
            final_response = chat_llm.call(
                messages=messages,
                tools=[crew_tool_schema],
                available_functions=available_functions,
            )

            # Append assistant's reply
            messages.append({"role": "assistant", "content": final_response})

            # Display assistant's reply
            click.secho(f"\nAssistant: {final_response}\n", fg="green")

        except KeyboardInterrupt:
            click.echo("\nExiting chat. Goodbye!")
            break
        except Exception as e:
            click.secho(f"An error occurred: {e}", fg="red")
            break


def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
    """
    Dynamically build a LiteLLM 'function' schema for the given crew.

    crew_name: The name of the crew (used for the function 'name').
    crew_inputs: A ChatInputs object containing crew_description
        and a list of input fields (each with a name & description).
    """
    properties = {}
    for field in crew_inputs.inputs:
        properties[field.name] = {
            "type": "string",
            "description": field.description or "No description provided",
        }

    required_fields = [field.name for field in crew_inputs.inputs]

    return {
        "type": "function",
        "function": {
            "name": crew_inputs.crew_name,
            "description": crew_inputs.crew_description or "No crew description",
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required_fields,
            },
        },
    }


def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs):
    """
    Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.

    Args:
        crew (Crew): The crew instance to run.
        messages (List[Dict[str, str]]): The chat messages up to this point.
        **kwargs: The inputs collected from the user.

    Returns:
        str: The output from the crew's execution.

    Raises:
        SystemExit: Exits the chat if an error occurs during crew execution.
    """
    try:
        # Serialize 'messages' to JSON string before adding to kwargs
        kwargs['crew_chat_messages'] = json.dumps(messages)

        # Run the crew with the provided inputs
        crew_output = crew.kickoff(inputs=kwargs)

        # Convert CrewOutput to a string to send back to the user
        result = str(crew_output)

        return result
    except Exception as e:
        # Exit the chat and show the error message
        click.secho("An error occurred while running the crew:", fg="red")
        click.secho(str(e), fg="red")
        sys.exit(1)


def load_crew_and_name() -> Tuple[Crew, str]:
    """
    Loads the crew by importing the crew class from the user's project.

    Returns:
        Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew.
    """
    # Get the current working directory
    cwd = Path.cwd()

    # Path to the pyproject.toml file
    pyproject_path = cwd / "pyproject.toml"
    if not pyproject_path.exists():
        raise FileNotFoundError("pyproject.toml not found in the current directory.")

    # Load the pyproject.toml file using 'tomli'
    with pyproject_path.open("rb") as f:
        pyproject_data = tomli.load(f)

    # Get the project name from the 'project' section
    project_name = pyproject_data["project"]["name"]
    folder_name = project_name

    # Derive the crew class name from the project name
    # E.g., if project_name is 'my_project', crew_class_name is 'MyProject'
    crew_class_name = project_name.replace("_", " ").title().replace(" ", "")

    # Add the 'src' directory to sys.path
    src_path = cwd / "src"
    if str(src_path) not in sys.path:
        sys.path.insert(0, str(src_path))

    # Import the crew module
    crew_module_name = f"{folder_name}.crew"
    try:
        crew_module = __import__(crew_module_name, fromlist=[crew_class_name])
    except ImportError as e:
        raise ImportError(f"Failed to import crew module {crew_module_name}: {e}")

    # Get the crew class from the module
    try:
        crew_class = getattr(crew_module, crew_class_name)
    except AttributeError:
        raise AttributeError(
            f"Crew class {crew_class_name} not found in module {crew_module_name}"
        )

    # Instantiate the crew
    crew_instance = crew_class().crew()
    return crew_instance, crew_class_name


def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
    """
    Generates the ChatInputs required for the crew by analyzing the tasks and agents.

    Args:
        crew (Crew): The crew object containing tasks and agents.
        crew_name (str): The name of the crew.
        chat_llm: The chat language model to use for AI calls.

    Returns:
        ChatInputs: An object containing the crew's name, description, and input fields.
    """
    # Extract placeholders from tasks and agents
    required_inputs = fetch_required_inputs(crew)

    # Generate descriptions for each input using AI
    input_fields = []
    for input_name in required_inputs:
        description = generate_input_description_with_ai(input_name, crew, chat_llm)
        input_fields.append(ChatInputField(name=input_name, description=description))

    # Generate crew description using AI
    crew_description = generate_crew_description_with_ai(crew, chat_llm)

    return ChatInputs(
        crew_name=crew_name,
        crew_description=crew_description,
        inputs=input_fields
    )


def fetch_required_inputs(crew: Crew) -> Set[str]:
    """
    Extracts placeholders from the crew's tasks and agents.

    Args:
        crew (Crew): The crew object.

    Returns:
        Set[str]: A set of placeholder names.
    """
    placeholder_pattern = re.compile(r"\{(.+?)\}")
    required_inputs: Set[str] = set()

    # Scan tasks
    for task in crew.tasks:
        text = f"{task.description or ''} {task.expected_output or ''}"
        required_inputs.update(placeholder_pattern.findall(text))

    # Scan agents
    for agent in crew.agents:
        text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
        required_inputs.update(placeholder_pattern.findall(text))

    return required_inputs


def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
    """
    Generates an input description using AI based on the context of the crew.

    Args:
        input_name (str): The name of the input placeholder.
        crew (Crew): The crew object.
        chat_llm: The chat language model to use for AI calls.

    Returns:
        str: A concise description of the input.
    """
    # Gather context from tasks and agents where the input is used
    context_texts = []
    placeholder_pattern = re.compile(r"\{(.+?)\}")

    for task in crew.tasks:
        if f"{{{input_name}}}" in task.description or f"{{{input_name}}}" in task.expected_output:
            # Replace placeholders with input names
            task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description)
            expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output)
            context_texts.append(f"Task Description: {task_description}")
            context_texts.append(f"Expected Output: {expected_output}")
    for agent in crew.agents:
        if f"{{{input_name}}}" in agent.role or f"{{{input_name}}}" in agent.goal or f"{{{input_name}}}" in agent.backstory:
            # Replace placeholders with input names
            agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role)
            agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal)
            agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory)
            context_texts.append(f"Agent Role: {agent_role}")
            context_texts.append(f"Agent Goal: {agent_goal}")
            context_texts.append(f"Agent Backstory: {agent_backstory}")

    context = "\n".join(context_texts)
    if not context:
        # If no context is found for the input, raise an exception as per instruction
        raise ValueError(f"No context found for input '{input_name}'.")

    prompt = (
        f"Based on the following context, write a concise description (15 words or less) of the input '{input_name}'.\n"
        "Provide only the description, without any extra text or labels. Do not include placeholders like '{topic}' in the description.\n"
        "Context:\n"
        f"{context}"
    )
    response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
    description = response.strip()

    return description


def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
    """
    Generates a brief description of the crew using AI.

    Args:
        crew (Crew): The crew object.
        chat_llm: The chat language model to use for AI calls.

    Returns:
        str: A concise description of the crew's purpose (15 words or less).
    """
    # Gather context from tasks and agents
    context_texts = []
    placeholder_pattern = re.compile(r"\{(.+?)\}")

    for task in crew.tasks:
        # Replace placeholders with input names
        task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description)
        expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output)
        context_texts.append(f"Task Description: {task_description}")
        context_texts.append(f"Expected Output: {expected_output}")
    for agent in crew.agents:
        # Replace placeholders with input names
        agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role)
        agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal)
        agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory)
        context_texts.append(f"Agent Role: {agent_role}")
        context_texts.append(f"Agent Goal: {agent_goal}")
        context_texts.append(f"Agent Backstory: {agent_backstory}")

    context = "\n".join(context_texts)
    if not context:
        raise ValueError("No context found for generating crew description.")

    prompt = (
        "Based on the following context, write a concise, action-oriented description (15 words or less) of the crew's purpose.\n"
        "Provide only the description, without any extra text or labels. Do not include placeholders like '{topic}' in the description.\n"
        "Context:\n"
        f"{context}"
    )
    response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
    crew_description = response.strip()

    return crew_description
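To make the control flow concrete: for a hypothetical crew named `research_crew` with a single `{topic}` placeholder, `generate_crew_tool_schema` would return a function schema shaped like this (name and descriptions illustrative):

    {
        "type": "function",
        "function": {
            "name": "research_crew",
            "description": "Researches a topic and writes a report.",
            "parameters": {
                "type": "object",
                "properties": {
                    "topic": {"type": "string", "description": "Topic to research"},
                },
                "required": ["topic"],
            },
        },
    }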
src/crewai/cli/fetch_chat_llm.py (new file, 81 lines)
@@ -0,0 +1,81 @@

import json
import subprocess

import click
from packaging import version

from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.llm import LLM


def fetch_chat_llm() -> LLM:
    """
    Fetch the chat LLM by running "uv run fetch_chat_llm" (or your chosen script name),
    parsing its JSON stdout, and returning an LLM instance.

    This expects the script "fetch_chat_llm" to print out JSON that represents the
    LLM parameters (e.g., by calling something like: print(json.dumps(llm.to_dict()))).

    Any error, whether from the subprocess or JSON parsing, will raise a RuntimeError.
    """

    # You may change this command to match whatever's in your pyproject.toml [project.scripts].
    command = ["uv", "run", "fetch_chat_llm"]

    crewai_version = get_crewai_version()
    min_required_version = "0.87.0"  # Adjust as needed

    pyproject_data = read_toml()

    # If old poetry-based setup is detected and version is below min_required_version
    if pyproject_data.get("tool", {}).get("poetry") and (
        version.parse(crewai_version) < version.parse(min_required_version)
    ):
        click.secho(
            f"You are running an older version of crewAI ({crewai_version}) that uses poetry pyproject.toml.\n"
            f"Please run `crewai update` to transition your pyproject.toml to use uv.",
            fg="red",
        )

    # Initialize a reference to your LLM
    llm_instance = None

    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
        stdout_lines = result.stdout.strip().splitlines()

        # Find the line that contains the JSON data
        json_line = next(
            (
                line
                for line in stdout_lines
                if line.startswith("{") and line.endswith("}")
            ),
            None,
        )

        if not json_line:
            raise RuntimeError(
                "No valid JSON output received from `fetch_chat_llm` command."
            )

        try:
            llm_data = json.loads(json_line)
            llm_instance = LLM.from_dict(llm_data)
        except json.JSONDecodeError as e:
            raise RuntimeError(
                f"Unable to parse JSON from `fetch_chat_llm` output: {e}\nOutput: {repr(json_line)}"
            ) from e

    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e
    except Exception as e:
        raise RuntimeError(
            f"An unexpected error occurred while fetching chat LLM: {e}"
        ) from e

    if not llm_instance:
        raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.")

    return llm_instance
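The subprocess contract here is deliberately thin: somewhere on the project script's stdout there must be a single `{...}` line. A minimal sketch of the producing side (the user's `fetch_chat_llm` entry point, as the docstring above suggests):

    import json
    from crewai.llm import LLM

    def main():
        llm = LLM(model="gpt-4o-mini")  # model name illustrative
        print(json.dumps(llm.to_dict()))  # the one JSON line the CLI scans for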
src/crewai/cli/fetch_crew_inputs.py (new file, 86 lines)
@@ -0,0 +1,86 @@

import json
import subprocess
from typing import Optional

import click
from packaging import version

from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.types.crew_chat import ChatInputs


def fetch_crew_inputs() -> Optional[ChatInputs]:
    """
    Fetch the crew's ChatInputs (a structure containing crew_description and input fields)
    by running "uv run fetch_chat_inputs", which prints JSON representing a ChatInputs object.

    This function will parse that JSON and return a ChatInputs instance.
    If the output is empty or invalid, None is returned.
    """

    command = ["uv", "run", "fetch_chat_inputs"]
    crewai_version = get_crewai_version()
    min_required_version = "0.87.0"

    pyproject_data = read_toml()
    crew_name = pyproject_data.get("project", {}).get("name", None)

    # If you're on an older poetry-based setup and version < min_required_version
    if pyproject_data.get("tool", {}).get("poetry") and (
        version.parse(crewai_version) < version.parse(min_required_version)
    ):
        click.secho(
            f"You are running an older version of crewAI ({crewai_version}) that uses poetry pyproject.toml.\n"
            f"Please run `crewai update` to update your pyproject.toml to use uv.",
            fg="red",
        )

    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
        stdout_lines = result.stdout.strip().splitlines()

        # Find the line that contains the JSON data
        json_line = next(
            (
                line
                for line in stdout_lines
                if line.startswith("{") and line.endswith("}")
            ),
            None,
        )

        if not json_line:
            click.echo(
                "No valid JSON output received from `fetch_chat_inputs` command.",
                err=True,
            )
            return None

        try:
            raw_data = json.loads(json_line)
            chat_inputs = ChatInputs(**raw_data)
            if crew_name:
                chat_inputs.crew_name = crew_name
            return chat_inputs
        except json.JSONDecodeError as e:
            click.echo(
                f"Unable to parse JSON from `fetch_chat_inputs` output: {e}\nOutput: {repr(json_line)}",
                err=True,
            )
            return None

    except subprocess.CalledProcessError as e:
        click.echo(f"An error occurred while fetching chat inputs: {e}", err=True)
        click.echo(e.output, err=True, nl=True)

        if pyproject_data.get("tool", {}).get("poetry"):
            click.secho(
                "It's possible that you are using an old version of crewAI that uses poetry.\n"
                "Please run `crewai update` to update your pyproject.toml to use uv.",
                fg="yellow",
            )
    except Exception as e:
        click.echo(f"An unexpected error occurred: {e}", err=True)

    return None
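For this command the expected stdout line is a serialized ChatInputs object. Assuming the field names used throughout this changeset (ChatInputs.crew_name / crew_description / inputs, ChatInputField.name / description), a valid payload would look like:

    # Hypothetical single-line payload printed by `uv run fetch_chat_inputs`:
    {
        "crew_name": "research_crew",
        "crew_description": "Researches a topic and writes a report.",
        "inputs": [{"name": "topic", "description": "Topic to research"}]
    }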
@@ -1,8 +1,10 @@
 #!/usr/bin/env python
 import sys
+import json
 import warnings
 
 from {{folder_name}}.crew import {{crew_name}}
+from crewai.utilities.llm_utils import create_llm
 
 warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
 
@@ -13,12 +15,30 @@ warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
 
 def run():
     """
-    Run the crew.
+    Run the crew, allowing CLI overrides for required inputs.
+
+    Usage example:
+      uv run run_crew -- --topic="New Topic" --some_other_field="Value"
     """
+    # Default inputs
     inputs = {
         'topic': 'AI LLMs'
+        # Add any other default fields here
     }
-    {{crew_name}}().crew().kickoff(inputs=inputs)
+
+    # 1) Gather overrides from sys.argv
+    # sys.argv might look like: ['run_crew', '--topic=NewTopic']
+    # But be aware that if you're calling "uv run run_crew", sys.argv might have
+    # additional items. So we typically skip the first 1 or 2 items to get only overrides.
+    overrides = parse_cli_overrides(sys.argv[1:])
+
+    # 2) Merge the overrides into defaults
+    inputs.update(overrides)
+
+    # 3) Kick off the crew with final inputs
+    try:
+        {{crew_name}}().crew().kickoff(inputs=inputs)
+    except Exception as e:
+        raise Exception(f"An error occurred while running the crew: {e}")
 
 
 def train():
@@ -55,4 +75,94 @@ def test():
         {{crew_name}}().crew().test(n_iterations=int(sys.argv[1]), openai_model_name=sys.argv[2], inputs=inputs)
 
     except Exception as e:
-        raise Exception(f"An error occurred while replaying the crew: {e}")
+        raise Exception(f"An error occurred while testing the crew: {e}")
+
+
+def fetch_inputs():
+    """
+    Command that gathers required placeholders/inputs from the Crew, then
+    prints them as JSON to stdout so external scripts can parse them easily.
+    """
+    try:
+        crew = {{crew_name}}().crew()
+        crew_inputs = crew.fetch_inputs()
+        json_string = json.dumps(list(crew_inputs))
+        print(json_string)
+    except Exception as e:
+        raise Exception(f"An error occurred while fetching inputs: {e}")
+
+
+def fetch_chat_llm():
+    """
+    Command that fetches the 'chat_llm' property from the Crew,
+    instantiates it via create_llm(),
+    and prints the resulting LLM as JSON (using LLM.to_dict()) to stdout.
+    """
+    try:
+        crew = {{crew_name}}().crew()
+        raw_chat_llm = getattr(crew, "chat_llm", None)
+
+        if not raw_chat_llm:
+            # If the crew doesn't have chat_llm, fallback to create_llm(None)
+            final_llm = create_llm(None)
+        else:
+            # raw_chat_llm might be a dict, or an LLM, or something else
+            final_llm = create_llm(raw_chat_llm)
+
+        if final_llm:
+            # Print the final LLM as JSON, so fetch_chat_llm.py can parse it
+            from crewai.llm import LLM  # Import here to avoid circular references
+
+            # Make sure it's an instance of the LLM class:
+            if isinstance(final_llm, LLM):
+                print(json.dumps(final_llm.to_dict()))
+            else:
+                # If somehow it's not an LLM, try to interpret as a dict
+                # or revert to an empty fallback
+                if isinstance(final_llm, dict):
+                    print(json.dumps(final_llm))
+                else:
+                    print(json.dumps({}))
+        else:
+            print(json.dumps({}))
+    except Exception as e:
+        raise Exception(f"An error occurred while fetching chat LLM: {e}")
+
+
+# TODO: Talk to Joao about using LLM calls to analyze the crew
+# and generate all of this information automatically
+def fetch_chat_inputs():
+    """
+    Command that fetches the 'chat_inputs' property from the Crew,
+    and prints it as JSON to stdout.
+    """
+    try:
+        crew = {{crew_name}}().crew()
+        raw_chat_inputs = getattr(crew, "chat_inputs", None)
+
+        if raw_chat_inputs:
+            # Convert to dictionary to print JSON
+            print(json.dumps(raw_chat_inputs.model_dump()))
+        else:
+            # If crew.chat_inputs is None or empty, print an empty JSON
+            print(json.dumps({}))
+    except Exception as e:
+        raise Exception(f"An error occurred while fetching chat inputs: {e}")
+
+
+def parse_cli_overrides(args_list) -> dict:
+    """
+    Parse arguments in the form of --key=value from a list of CLI arguments.
+    Return them as a dict. For example:
+    ['--topic=AI LLMs', '--username=John'] => {'topic': 'AI LLMs', 'username': 'John'}
+    """
+    overrides = {}
+    for arg in args_list:
+        if arg.startswith("--"):
+            # remove the leading --
+            trimmed = arg[2:]
+            if "=" in trimmed:
+                key, val = trimmed.split("=", 1)
+                overrides[key] = val
+            else:
+                # If someone passed something like --topic (no =),
+                # either handle differently or ignore
+                pass
+    return overrides
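A worked example of the override flow in run(): defaults are declared inline, then --key=value arguments win on update:

    inputs = {'topic': 'AI LLMs'}                                    # defaults
    overrides = parse_cli_overrides(["--topic=Quantum error correction"])
    inputs.update(overrides)
    print(inputs)  # {'topic': 'Quantum error correction'}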
@@ -1,10 +1,11 @@
 import asyncio
 import json
+import re
 import uuid
 import warnings
 from concurrent.futures import Future
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 from pydantic import (
     UUID4,
@@ -36,6 +37,7 @@ from crewai.tasks.task_output import TaskOutput
 from crewai.telemetry import Telemetry
 from crewai.tools.agent_tools.agent_tools import AgentTools
 from crewai.tools.base_tool import Tool
+from crewai.types.crew_chat import ChatInputs
 from crewai.types.usage_metrics import UsageMetrics
 from crewai.utilities import I18N, FileHandler, Logger, RPMController
 from crewai.utilities.constants import TRAINING_DATA_FILE
@@ -203,6 +205,14 @@ class Crew(BaseModel):
         default=None,
         description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
     )
+    chat_llm: Optional[Any] = Field(
+        default=None,
+        description="LLM used to handle chatting with the crew.",
+    )
+    chat_inputs: Optional[ChatInputs] = Field(
+        default=None,
+        description="Holds descriptions of the crew as well as named inputs for chat usage.",
+    )
     _knowledge: Optional[Knowledge] = PrivateAttr(
         default=None,
     )
@@ -991,6 +1001,31 @@ class Crew(BaseModel):
             return self._knowledge.query(query)
         return None
 
+    def fetch_inputs(self) -> Set[str]:
+        """
+        Gathers placeholders (e.g., {something}) referenced in tasks or agents.
+        Scans each task's 'description' + 'expected_output', and each agent's
+        'role', 'goal', and 'backstory'.
+
+        Returns a set of all discovered placeholder names.
+        """
+        placeholder_pattern = re.compile(r"\{(.+?)\}")
+        required_inputs: Set[str] = set()
+
+        # Scan tasks for inputs
+        for task in self.tasks:
+            # description and expected_output might contain e.g. {topic}, {user_name}, etc.
+            text = f"{task.description or ''} {task.expected_output or ''}"
+            required_inputs.update(placeholder_pattern.findall(text))
+
+        # Scan agents for inputs
+        for agent in self.agents:
+            # role, goal, backstory might have placeholders like {role_detail}, etc.
+            text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
+            required_inputs.update(placeholder_pattern.findall(text))
+
+        return required_inputs
+
     def copy(self):
         """Create a deep copy of the Crew."""
 
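Crew.fetch_inputs() is a plain regex scan, so an "input" is exactly whatever sits inside single braces in the prose fields:

    import re

    placeholder_pattern = re.compile(r"\{(.+?)\}")
    text = "Research {topic} and summarize it for {audience}."
    print(placeholder_pattern.findall(text))  # ['topic', 'audience']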
@@ -1046,7 +1081,7 @@ class Crew(BaseModel):
     def _interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
         """Interpolates the inputs in the tasks and agents."""
         [
-            task.interpolate_inputs(
+            task.interpolate_inputs_and_add_conversation_history(
                 # type: ignore # "interpolate_inputs" of "Task" does not return a value (it only ever returns None)
                 inputs
             )
@@ -1,20 +1,27 @@
+import json
 import logging
 import os
 import sys
 import threading
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
 
+from dotenv import load_dotenv
+
 with warnings.catch_warnings():
     warnings.simplefilter("ignore", UserWarning)
     import litellm
-    from litellm import get_supported_openai_params
+    from litellm import Choices, get_supported_openai_params
+    from litellm.types.utils import ModelResponse
 
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededException,
 )
 
+load_dotenv()
+
 
 class FilteredStream:
     def __init__(self, original_stream):
@@ -23,6 +30,7 @@ class FilteredStream:
 
     def write(self, s) -> int:
         with self._lock:
+            # Filter out extraneous messages from LiteLLM
            if (
                 "Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
                 in s
@@ -84,11 +92,9 @@ def suppress_warnings():
     old_stderr = sys.stderr
     sys.stdout = FilteredStream(old_stdout)
     sys.stderr = FilteredStream(old_stderr)
-
     try:
         yield
     finally:
-        # Restore stdout and stderr
         sys.stdout = old_stdout
         sys.stderr = old_stderr
 
@@ -109,13 +115,12 @@ class LLM:
         logit_bias: Optional[Dict[int, float]] = None,
         response_format: Optional[Dict[str, Any]] = None,
         seed: Optional[int] = None,
-        logprobs: Optional[bool] = None,
+        logprobs: Optional[int] = None,
         top_logprobs: Optional[int] = None,
         base_url: Optional[str] = None,
         api_version: Optional[str] = None,
         api_key: Optional[str] = None,
         callbacks: List[Any] = [],
-        **kwargs,
     ):
         self.model = model
         self.timeout = timeout
@@ -137,19 +142,96 @@ class LLM:
         self.api_key = api_key
         self.callbacks = callbacks
         self.context_window_size = 0
-        self.kwargs = kwargs
 
+        # For safety, we disable passing init params to next calls
         litellm.drop_params = True
+
         self.set_callbacks(callbacks)
         self.set_env_callbacks()
 
-    def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
+    def to_dict(self) -> dict:
+        """
+        Return a dict of all relevant parameters for serialization.
+        """
+        return {
+            "model": self.model,
+            "timeout": self.timeout,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "n": self.n,
+            "stop": self.stop,
+            "max_completion_tokens": self.max_completion_tokens,
+            "max_tokens": self.max_tokens,
+            "presence_penalty": self.presence_penalty,
+            "frequency_penalty": self.frequency_penalty,
+            "logit_bias": self.logit_bias,
+            "response_format": self.response_format,
+            "seed": self.seed,
+            "logprobs": self.logprobs,
+            "top_logprobs": self.top_logprobs,
+            "base_url": self.base_url,
+            "api_version": self.api_version,
+            "api_key": self.api_key,
+            "callbacks": self.callbacks,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "LLM":
+        """
+        Create an LLM instance from a dict.
+        We assume the dict has all relevant keys that match what's in the constructor.
+        """
+        known_fields = {}
+        known_fields["model"] = data.pop("model", None)
+        known_fields["timeout"] = data.pop("timeout", None)
+        known_fields["temperature"] = data.pop("temperature", None)
+        known_fields["top_p"] = data.pop("top_p", None)
+        known_fields["n"] = data.pop("n", None)
+        known_fields["stop"] = data.pop("stop", None)
+        known_fields["max_completion_tokens"] = data.pop("max_completion_tokens", None)
+        known_fields["max_tokens"] = data.pop("max_tokens", None)
+        known_fields["presence_penalty"] = data.pop("presence_penalty", None)
+        known_fields["frequency_penalty"] = data.pop("frequency_penalty", None)
+        known_fields["logit_bias"] = data.pop("logit_bias", None)
+        known_fields["response_format"] = data.pop("response_format", None)
+        known_fields["seed"] = data.pop("seed", None)
+        known_fields["logprobs"] = data.pop("logprobs", None)
+        known_fields["top_logprobs"] = data.pop("top_logprobs", None)
+        known_fields["base_url"] = data.pop("base_url", None)
+        known_fields["api_version"] = data.pop("api_version", None)
+        known_fields["api_key"] = data.pop("api_key", None)
+        known_fields["callbacks"] = data.pop("callbacks", None)
+
+        return cls(**known_fields, **data)
+
+    def call(
+        self,
+        messages: List[Dict[str, str]],
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        """
+        High-level call method that:
+          1) Calls litellm.completion
+          2) Checks for function/tool calls
+          3) If a tool call is found:
+             a) executes the function
+             b) returns the result
+          4) If no tool call, returns the text response
+
+        :param messages: The conversation messages
+        :param tools: Optional list of function schemas for function calling
+        :param callbacks: Optional list of callbacks
+        :param available_functions: A dictionary mapping function_name -> actual Python function
+        :return: Final text response from the LLM or the tool result
+        """
         with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
+            if callbacks:
                 self.set_callbacks(callbacks)
 
             try:
+                # --- 1) Make the completion call
                 params = {
                     "model": self.model,
                     "messages": messages,
@@ -170,21 +252,65 @@ class LLM:
                     "api_version": self.api_version,
                     "api_key": self.api_key,
                     "stream": False,
-                    **self.kwargs,
+                    "tools": tools,  # pass the tool schema
                 }
 
-                # Remove None values to avoid passing unnecessary parameters
+                # Remove None values
                 params = {k: v for k, v in params.items() if v is not None}
 
                 response = litellm.completion(**params)
-                return response["choices"][0]["message"]["content"]
+                response_message = cast(Choices, cast(ModelResponse, response).choices)[
+                    0
+                ].message
+                text_response = response_message.content or ""
+                tool_calls = getattr(response_message, "tool_calls", [])
+
+                # --- 2) If no tool calls, return the text response
+                if not tool_calls or not available_functions:
+                    return text_response
+
+                # --- 3) Handle the tool call
+                tool_call = tool_calls[0]
+                function_name = tool_call.function.name
+
+                if function_name in available_functions:
+                    # Parse arguments
+                    try:
+                        function_args = json.loads(tool_call.function.arguments)
+                    except json.JSONDecodeError as e:
+                        logging.warning(f"Failed to parse function arguments: {e}")
+                        return text_response  # Fallback to text response
+
+                    fn = available_functions[function_name]
+                    try:
+                        # Call the actual tool function
+                        result = fn(**function_args)
+
+                        print(f"Result from function '{function_name}': {result}")
+
+                        # Return the result directly
+                        return result
+
+                    except Exception as e:
+                        logging.error(
+                            f"Error executing function '{function_name}': {e}"
+                        )
+                        return text_response  # Fallback to text response
+
+                else:
+                    logging.warning(
+                        f"Tool call requested unknown function '{function_name}'"
+                    )
+                    return text_response  # Fallback to text response
+
             except Exception as e:
+                # Check if context length was exceeded, otherwise log
                 if not LLMContextLengthExceededException(
                     str(e)
                 )._is_context_limit_error(str(e)):
                     logging.error(f"LiteLLM call failed: {str(e)}")
+                # Re-raise the exception
                 raise
 
     def supports_function_calling(self) -> bool:
         try:
@@ -203,7 +329,10 @@ class LLM:
         return False
 
     def get_context_window_size(self) -> int:
-        # Only using 75% of the context window size to avoid cutting the message in the middle
+        """
+        Returns the context window size, using 75% of the maximum to avoid
+        cutting off messages mid-thread.
+        """
         if self.context_window_size != 0:
             return self.context_window_size
 
@@ -216,6 +345,10 @@ class LLM:
         return self.context_window_size
 
     def set_callbacks(self, callbacks: List[Any]):
+        """
+        Attempt to keep a single set of callbacks in litellm by removing old
+        duplicates and adding new ones.
+        """
         callback_types = [type(callback) for callback in callbacks]
         for callback in litellm.success_callback[:]:
             if type(callback) in callback_types:
@@ -230,34 +363,19 @@ class LLM:
     def set_env_callbacks(self):
         """
         Sets the success and failure callbacks for the LiteLLM library from environment variables.
-
-        This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
-        environment variables, which should contain comma-separated lists of callback names.
-        It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`,
-        respectively.
-
-        If the environment variables are not set or are empty, the corresponding callback lists
-        will be set to empty lists.
-
-        Example:
-            LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
-            LITELLM_FAILURE_CALLBACKS="langfuse"
-
-        This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
-        `litellm.failure_callback` to ["langfuse"].
         """
         success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
         success_callbacks = []
         if success_callbacks_str:
             success_callbacks = [
-                callback.strip() for callback in success_callbacks_str.split(",")
+                cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
             ]
 
         failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
         failure_callbacks = []
         if failure_callbacks_str:
             failure_callbacks = [
-                callback.strip() for callback in failure_callbacks_str.split(",")
+                cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
             ]
 
         litellm.success_callback = success_callbacks
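to_dict()/from_dict() are what let the CLI ship an LLM configuration across the `uv run` subprocess boundary as JSON. The round trip, as implied by fetch_chat_llm.py above:

    import json
    from crewai.llm import LLM

    llm = LLM(model="gpt-4o-mini", temperature=0.7)  # values illustrative
    payload = json.dumps(llm.to_dict())              # what the project script prints
    clone = LLM.from_dict(json.loads(payload))       # what the CLI reconstructs
    assert clone.model == llm.model and clone.temperature == llm.temperature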
@@ -449,9 +449,11 @@ class Task(BaseModel):
         tasks_slices = [self.description, output]
         return "\n".join(tasks_slices)

-    def interpolate_inputs(self, inputs: Dict[str, Union[str, int, float]]) -> None:
-        """Interpolate inputs into the task description, expected output, and output file path.
+    def interpolate_inputs_and_add_conversation_history(self, inputs: Dict[str, Union[str, int, float]]) -> None:
+        """Interpolate inputs into the task description, expected output, and output file path.
+
+        Add conversation history if present.

         Args:
             inputs: Dictionary mapping template variables to their values.
                 Supported value types are strings, integers, and floats.
@@ -491,9 +493,33 @@ class Task(BaseModel):
                 input_string=self._original_output_file, inputs=inputs
             )
         except (KeyError, ValueError) as e:
-            raise ValueError(
-                f"Error interpolating output_file path: {str(e)}"
-            ) from e
+            raise ValueError(f"Error interpolating output_file path: {str(e)}") from e
+
+        if "crew_chat_messages" in inputs and inputs["crew_chat_messages"]:
+            # Fetch the conversation history instruction using self.i18n.slice
+            conversation_instruction = self.i18n.slice(
+                "conversation_history_instruction"
+            )
+            print("crew_chat_messages:", inputs["crew_chat_messages"])
+
+            # Ensure that inputs["crew_chat_messages"] is a string
+            crew_chat_messages_json = str(inputs["crew_chat_messages"])
+
+            try:
+                crew_chat_messages = json.loads(crew_chat_messages_json)
+            except json.JSONDecodeError as e:
+                print("An error occurred while parsing crew chat messages:", e)
+                raise
+
+            # Process the messages to build conversation history
+            conversation_history = "\n".join(
+                f"{msg['role'].capitalize()}: {msg['content']}"
+                for msg in crew_chat_messages
+                if isinstance(msg, dict) and "role" in msg and "content" in msg
+            )
+
+            # Add the instruction and conversation history to the description
+            self.description += f"\n\n{conversation_instruction}\n\n{conversation_history}"

     def interpolate_only(
         self, input_string: Optional[str], inputs: Dict[str, Union[str, int, float]]
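To make the expected `crew_chat_messages` payload concrete, here is a standalone sketch of the parse-and-format logic above; the message list is invented for illustration:

    import json

    crew_chat_messages_json = json.dumps([
        {"role": "user", "content": "Focus the report on Q3."},
        {"role": "assistant", "content": "Understood, narrowing scope to Q3."},
    ])

    messages = json.loads(crew_chat_messages_json)
    conversation_history = "\n".join(
        f"{msg['role'].capitalize()}: {msg['content']}"
        for msg in messages
        if isinstance(msg, dict) and "role" in msg and "content" in msg
    )
    # conversation_history ==
    # "User: Focus the report on Q3.\nAssistant: Understood, narrowing scope to Q3."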
@@ -23,7 +23,8 @@
     "summary": "This is a summary of our conversation so far:\n{merged_summary}",
     "manager_request": "Your best answer to your coworker asking you this, accounting for the context shared.",
     "formatted_task_instructions": "Ensure your final answer contains only the content in the following format: {output_format}\n\nEnsure the final output does not include any code block markers like ```json or ```python.",
-    "human_feedback_classification": "Determine if the following feedback indicates that the user is satisfied or if further changes are needed. Respond with 'True' if further changes are needed, or 'False' if the user is satisfied. **Important** Do not include any additional commentary outside of your 'True' or 'False' response.\n\nFeedback: \"{feedback}\""
+    "human_feedback_classification": "Determine if the following feedback indicates that the user is satisfied or if further changes are needed. Respond with 'True' if further changes are needed, or 'False' if the user is satisfied. **Important** Do not include any additional commentary outside of your 'True' or 'False' response.\n\nFeedback: \"{feedback}\"",
+    "conversation_history_instruction": "You are a member of a crew collaborating to achieve a common goal. Your task is a specific action that contributes to this larger objective. For additional context, please review the conversation history between you and the user that led to the initiation of this crew. Use any relevant information or feedback from the conversation to inform your task execution and ensure your response aligns with both the immediate task and the crew's overall goals."
   },
   "errors": {
     "force_final_answer_error": "You can't keep going, this was the best you could do.\n {formatted_answer.text}",
src/crewai/types/crew_chat.py  (new file, 44 lines)
@@ -0,0 +1,44 @@
+from typing import List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class ChatInputField(BaseModel):
+    """
+    Represents a single required input for the crew, with a name and short description.
+    Example:
+    {
+        "name": "topic",
+        "description": "The topic to focus on for the conversation"
+    }
+    """
+
+    name: str = Field(..., description="The name of the input field")
+    description: str = Field(
+        ...,
+        description="A short description of the input field",
+    )
+
+
+class ChatInputs(BaseModel):
+    """
+    Holds a high-level crew_description plus a list of ChatInputFields.
+    Example:
+    {
+        "crew_name": "topic-based-qa",
+        "crew_description": "Use this crew for topic-based Q&A",
+        "inputs": [
+            {"name": "topic", "description": "The topic to focus on"},
+            {"name": "username", "description": "Name of the user"}
+        ]
+    }
+    """
+
+    crew_name: str = Field(..., description="The name of the crew")
+    crew_description: str = Field(
+        ...,
+        description="A description of the crew's purpose",
+    )
+    inputs: List[ChatInputField] = Field(
+        default_factory=list, description="A list of input fields for the crew"
+    )
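A hedged usage sketch for the new models, mirroring the docstring examples (pydantic v2 API assumed):

    from crewai.types.crew_chat import ChatInputField, ChatInputs

    chat_inputs = ChatInputs(
        crew_name="topic-based-qa",
        crew_description="Use this crew for topic-based Q&A",
        inputs=[
            ChatInputField(name="topic", description="The topic to focus on"),
            ChatInputField(name="username", description="Name of the user"),
        ],
    )
    print(chat_inputs.model_dump_json(indent=2))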
src/crewai/utilities/llm_utils.py  (new file, 215 lines)
@@ -0,0 +1,215 @@
+import os
+from typing import Any, Dict, List, Optional, Union
+
+from packaging import version
+
+from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
+from crewai.cli.utils import read_toml
+from crewai.cli.version import get_crewai_version
+from crewai.llm import LLM
+
+
+def create_llm(
+    llm_value: Union[str, LLM, Any, None] = None,
+) -> Optional[LLM]:
+    """
+    Creates or returns an LLM instance based on the given llm_value.
+
+    Args:
+        llm_value (str | LLM | Any | None):
+            - str: The model name (e.g., "gpt-4").
+            - LLM: Already instantiated LLM, returned as-is.
+            - Any: Attempt to extract known attributes like model_name, temperature, etc.
+            - None: Use environment-based or fallback default model.
+
+    Returns:
+        An LLM instance if successful, or None if something fails.
+    """
+
+    # 1) If llm_value is already an LLM object, return it directly
+    if isinstance(llm_value, LLM):
+        return llm_value
+
+    # 2) If llm_value is a string (model name)
+    if isinstance(llm_value, str):
+        try:
+            created_llm = LLM(model=llm_value)
+            print(f"LLM created with model='{llm_value}'")
+            return created_llm
+        except Exception as e:
+            print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
+            return None
+
+    # 3) If llm_value is None, parse environment variables or use default
+    if llm_value is None:
+        return _llm_via_environment_or_fallback()
+
+    # 4) Otherwise, attempt to extract relevant attributes from an unknown object
+    try:
+        # Extract attributes with explicit types
+        model = (
+            getattr(llm_value, "model_name", None)
+            or getattr(llm_value, "deployment_name", None)
+            or str(llm_value)
+        )
+        temperature: Optional[float] = getattr(llm_value, "temperature", None)
+        max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None)
+        logprobs: Optional[int] = getattr(llm_value, "logprobs", None)
+        timeout: Optional[float] = getattr(llm_value, "timeout", None)
+        api_key: Optional[str] = getattr(llm_value, "api_key", None)
+        base_url: Optional[str] = getattr(llm_value, "base_url", None)
+
+        created_llm = LLM(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            logprobs=logprobs,
+            timeout=timeout,
+            api_key=api_key,
+            base_url=base_url,
+        )
+        print(f"LLM created with extracted parameters; model='{model}'")
+        return created_llm
+    except Exception as e:
+        print(f"Error instantiating LLM from unknown object type: {e}")
+        return None
+
+
+def create_chat_llm() -> Optional[LLM]:
+    """
+    Creates a Chat LLM with additional checks, such as verifying the crewAI
+    version or reading from pyproject.toml, then calls `create_llm(None)` so
+    the model is resolved from the environment or the fallback default.
+
+    Returns:
+        An instance of LLM or None if instantiation fails.
+    """
+    print("[create_chat_llm] Checking environment and version info...")
+
+    crewai_version = get_crewai_version()
+    min_required_version = "0.87.0"  # Update to latest if needed
+
+    pyproject_data = read_toml()
+    if pyproject_data.get("tool", {}).get("poetry") and (
+        version.parse(crewai_version) < version.parse(min_required_version)
+    ):
+        print(
+            f"You are running an older version of crewAI ({crewai_version}) that uses poetry.\n"
+            "Please run `crewai update` to switch to uv-based builds."
+        )
+
+    # After checks, simply call create_llm with None (meaning "use env or fallback"):
+    return create_llm(None)
+
+
+def _llm_via_environment_or_fallback() -> Optional[LLM]:
+    """
+    Helper function: when llm_value is None, load configuration from
+    environment variables or fall back to the default model.
+    """
+    model_name = (
+        os.environ.get("OPENAI_MODEL_NAME")
+        or os.environ.get("MODEL")
+        or DEFAULT_LLM_MODEL
+    )
+
+    # Initialize parameters with correct types
+    model: str = model_name
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+    max_completion_tokens: Optional[int] = None
+    logprobs: Optional[int] = None
+    timeout: Optional[float] = None
+    api_key: Optional[str] = None
+    base_url: Optional[str] = None
+    api_version: Optional[str] = None
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = None
+    logit_bias: Optional[Dict[int, float]] = None
+    response_format: Optional[Dict[str, Any]] = None
+    seed: Optional[int] = None
+    top_logprobs: Optional[int] = None
+    callbacks: List[Any] = []
+
+    # Optional base URL from env
+    api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL")
+    if api_base:
+        base_url = api_base
+
+    # Initialize llm_params dictionary
+    llm_params: Dict[str, Any] = {
+        "model": model,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+        "max_completion_tokens": max_completion_tokens,
+        "logprobs": logprobs,
+        "timeout": timeout,
+        "api_key": api_key,
+        "base_url": base_url,
+        "api_version": api_version,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
+        "top_p": top_p,
+        "n": n,
+        "stop": stop,
+        "logit_bias": logit_bias,
+        "response_format": response_format,
+        "seed": seed,
+        "top_logprobs": top_logprobs,
+        "callbacks": callbacks,
+    }
+
+    UNACCEPTED_ATTRIBUTES = [
+        "AWS_ACCESS_KEY_ID",
+        "AWS_SECRET_ACCESS_KEY",
+        "AWS_REGION_NAME",
+    ]
+    set_provider = model_name.split("/")[0] if "/" in model_name else "openai"
+
+    if set_provider in ENV_VARS:
+        for env_var in ENV_VARS[set_provider]:
+            key_name = env_var.get("key_name")
+            if key_name and key_name not in UNACCEPTED_ATTRIBUTES:
+                env_value = os.environ.get(key_name)
+                if env_value:
+                    # Map environment variable names to recognized parameters
+                    param_key = _normalize_key_name(key_name.lower())
+                    llm_params[param_key] = env_value
+            elif isinstance(env_var, dict):
+                if env_var.get("default", False):
+                    for key, value in env_var.items():
+                        if key not in ["prompt", "key_name", "default"]:
+                            if key in os.environ:
+                                llm_params[key] = os.environ[key]
+            else:
+                print(f"Expected env_var to be a dictionary, but got {type(env_var)}")
+
+    # Remove None values
+    llm_params = {k: v for k, v in llm_params.items() if v is not None}
+
+    # Try creating the LLM
+    try:
+        new_llm = LLM(**llm_params)
+        print(f"LLM created with model='{model_name}'")
+        return new_llm
+    except Exception as e:
+        print(
+            f"Error instantiating LLM from environment/fallback: {type(e).__name__}: {e}"
+        )
+        return None
+
+
+def _normalize_key_name(key_name: str) -> str:
+    """
+    Maps environment variable names to recognized litellm parameter keys,
+    using patterns from LITELLM_PARAMS.
+    """
+    for pattern in LITELLM_PARAMS:
+        if pattern in key_name:
+            return pattern
+    return key_name
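A short sketch of the three main call shapes `create_llm` accepts; the model names are placeholders:

    from crewai.llm import LLM
    from crewai.utilities.llm_utils import create_llm

    llm_a = create_llm("gpt-4o-mini")        # str -> new LLM for that model
    llm_b = create_llm(LLM(model="gpt-4o"))  # LLM -> returned as-is
    llm_c = create_llm(None)                 # None -> env vars or DEFAULT_LLM_MODEL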
@@ -1,4 +1,5 @@
 import warnings
+from typing import Any, Dict, Optional

 from litellm.integrations.custom_logger import CustomLogger
 from litellm.types.utils import Usage
@@ -7,10 +8,16 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess


 class TokenCalcHandler(CustomLogger):
-    def __init__(self, token_cost_process: TokenProcess):
+    def __init__(self, token_cost_process: Optional[TokenProcess]):
         self.token_cost_process = token_cost_process

-    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+    def log_success_event(
+        self,
+        kwargs: Dict[str, Any],
+        response_obj: Dict[str, Any],
+        start_time: float,
+        end_time: float,
+    ) -> None:
         if self.token_cost_process is None:
             return
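The widened `Optional[TokenProcess]` annotation matches the existing early-return guard; a minimal sketch:

    from crewai.utilities.token_counter_callback import TokenCalcHandler

    handler = TokenCalcHandler(None)
    # With no TokenProcess attached, log_success_event returns immediately.
    handler.log_success_event({}, {}, 0.0, 0.0)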
@@ -3087,6 +3087,28 @@ def test_hierarchical_verbose_false_manager_agent():
     assert not crew.manager_agent.verbose


+def test_fetch_inputs():
+    agent = Agent(
+        role="{role_detail} Researcher",
+        goal="Research on {topic}.",
+        backstory="Expert in {field}.",
+    )
+
+    task = Task(
+        description="Analyze the data on {topic}.",
+        expected_output="Summary of {topic} analysis.",
+        agent=agent,
+    )
+
+    crew = Crew(agents=[agent], tasks=[task])
+
+    expected_placeholders = {"role_detail", "topic", "field"}
+    actual_placeholders = crew.fetch_inputs()
+
+    assert (
+        actual_placeholders == expected_placeholders
+    ), f"Expected {expected_placeholders}, but got {actual_placeholders}"
+
+
 def test_task_tools_preserve_code_execution_tools():
     """
     Test that task tools don't override code execution tools when allow_code_execution=True