fix llm_utils.py and other type errors
@@ -502,8 +502,11 @@ class Task(BaseModel):
         )
         print("crew_chat_messages:", inputs["crew_chat_messages"])

+        # Ensure that inputs["crew_chat_messages"] is a string
+        crew_chat_messages_json = str(inputs["crew_chat_messages"])
+
         try:
-            crew_chat_messages = json.loads(inputs["crew_chat_messages"])
+            crew_chat_messages = json.loads(crew_chat_messages_json)
         except json.JSONDecodeError as e:
             print("An error occurred while parsing crew chat messages:", e)
             raise
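For reference, a minimal standalone sketch of the parsing pattern introduced above, assuming an inputs dict that carries a crew_chat_messages entry (the surrounding Task code is not reproduced here):

import json

# Hypothetical inputs dict mirroring the key used in the hunk above.
inputs = {"crew_chat_messages": '[{"role": "user", "content": "hi"}]'}

# str() guarantees json.loads receives a str even if the stored value is not one.
crew_chat_messages_json = str(inputs["crew_chat_messages"])
try:
    crew_chat_messages = json.loads(crew_chat_messages_json)
except json.JSONDecodeError as e:
    print("An error occurred while parsing crew chat messages:", e)
    raise

print(crew_chat_messages)  # [{'role': 'user', 'content': 'hi'}]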
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 from packaging import version

@@ -21,8 +21,6 @@ def create_llm(
         - LLM: Already instantiated LLM, returned as-is.
         - Any: Attempt to extract known attributes like model_name, temperature, etc.
         - None: Use environment-based or fallback default model.
-        default_model (str): The fallback model name to use if llm_value is None
-            and no environment variable is set.

     Returns:
         An LLM instance if successful, or None if something fails.
@@ -46,30 +44,33 @@ def create_llm(
     if llm_value is None:
         return _llm_via_environment_or_fallback()

-    # 4) Otherwise, attempt to extract relevant attributes from an unknown object (like a config)
-    # e.g. follow the approach used in agent.py
+    # 4) Otherwise, attempt to extract relevant attributes from an unknown object
     try:
-        llm_params = {
-            "model": (
+        # Extract attributes with explicit types
+        model = (
             getattr(llm_value, "model_name", None)
             or getattr(llm_value, "deployment_name", None)
             or str(llm_value)
-            ),
-            "temperature": getattr(llm_value, "temperature", None),
-            "max_tokens": getattr(llm_value, "max_tokens", None),
-            "logprobs": getattr(llm_value, "logprobs", None),
-            "timeout": getattr(llm_value, "timeout", None),
-            "max_retries": getattr(llm_value, "max_retries", None),
-            "api_key": getattr(llm_value, "api_key", None),
-            "base_url": getattr(llm_value, "base_url", None),
-            "organization": getattr(llm_value, "organization", None),
-        }
-        # Remove None values
-        llm_params = {k: v for k, v in llm_params.items() if v is not None}
-        created_llm = LLM(**llm_params)
+        )
+        temperature: Optional[float] = getattr(llm_value, "temperature", None)
+        max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None)
+        logprobs: Optional[int] = getattr(llm_value, "logprobs", None)
+        timeout: Optional[float] = getattr(llm_value, "timeout", None)
+        api_key: Optional[str] = getattr(llm_value, "api_key", None)
+        base_url: Optional[str] = getattr(llm_value, "base_url", None)
+
+        created_llm = LLM(
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            logprobs=logprobs,
+            timeout=timeout,
+            api_key=api_key,
+            base_url=base_url,
+        )
         print(
             "LLM created with extracted parameters; "
-            f"model='{llm_params.get('model', 'UNKNOWN')}'"
+            f"model='{model}'"
         )
         return created_llm
     except Exception as e:
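A hedged sketch of how the rewritten extraction branch can be exercised: any object exposing attributes such as model_name or temperature is accepted, and the remaining parameters fall back to None. FakeLLMConfig is invented for illustration, and the import path for llm_utils is an assumption about the package layout.

from dataclasses import dataclass
from typing import Optional

from crewai.utilities.llm_utils import create_llm  # import path assumed


@dataclass
class FakeLLMConfig:
    # Hypothetical config object; only the attribute names matter to create_llm.
    model_name: str = "gpt-4o-mini"
    temperature: Optional[float] = 0.2
    max_tokens: Optional[int] = 512


llm = create_llm(FakeLLMConfig())  # takes the getattr-based branch shown above
print(llm)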
@@ -77,7 +78,7 @@ def create_llm(
         return None


-def create_chat_llm(default_model: str = "gpt-4") -> Optional[LLM]:
+def create_chat_llm() -> Optional[LLM]:
     """
     Creates a Chat LLM with additional checks, such as verifying crewAI version
     or reading from pyproject.toml. Then calls `create_llm(None, default_model)`.
@@ -115,12 +116,55 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
         or os.environ.get("MODEL")
         or DEFAULT_LLM_MODEL
     )
-    llm_params = {"model": model_name}
+
+    # Initialize parameters with correct types
+    model: str = model_name
+    temperature: Optional[float] = None
+    max_tokens: Optional[int] = None
+    max_completion_tokens: Optional[int] = None
+    logprobs: Optional[int] = None
+    timeout: Optional[float] = None
+    api_key: Optional[str] = None
+    base_url: Optional[str] = None
+    api_version: Optional[str] = None
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = None
+    logit_bias: Optional[Dict[int, float]] = None
+    response_format: Optional[Dict[str, Any]] = None
+    seed: Optional[int] = None
+    top_logprobs: Optional[int] = None
+    callbacks: List[Any] = []

     # Optional base URL from env
     api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL")
     if api_base:
-        llm_params["base_url"] = api_base
+        base_url = api_base
+
+    # Initialize llm_params dictionary
+    llm_params: Dict[str, Any] = {
+        "model": model,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+        "max_completion_tokens": max_completion_tokens,
+        "logprobs": logprobs,
+        "timeout": timeout,
+        "api_key": api_key,
+        "base_url": base_url,
+        "api_version": api_version,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
+        "top_p": top_p,
+        "n": n,
+        "stop": stop,
+        "logit_bias": logit_bias,
+        "response_format": response_format,
+        "seed": seed,
+        "top_logprobs": top_logprobs,
+        "callbacks": callbacks,
+    }
+
     UNACCEPTED_ATTRIBUTES = [
         "AWS_ACCESS_KEY_ID",
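A hedged sketch of the environment-driven fallback that the block above prepares: MODEL selects the model name and OPENAI_API_BASE feeds base_url, while every parameter still set to None is dropped before the LLM is constructed (see the filtering step in the next hunk). The import path is again an assumption.

import os

from crewai.utilities.llm_utils import create_llm  # import path assumed

# Hypothetical environment; adjust to whatever provider you actually use.
os.environ["MODEL"] = "gpt-4o-mini"
os.environ["OPENAI_API_BASE"] = "http://localhost:8000/v1"

llm = create_llm(None)  # delegates to _llm_via_environment_or_fallback()
print(llm)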
@@ -135,14 +179,17 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
         if key_name and key_name not in UNACCEPTED_ATTRIBUTES:
             env_value = os.environ.get(key_name)
             if env_value:
-                # Map environment variable names to recognized LITELLM_PARAMS if any
+                # Map environment variable names to recognized parameters
                 param_key = _normalize_key_name(key_name.lower())
                 llm_params[param_key] = env_value
         elif env_var.get("default", False):
             for key, value in env_var.items():
                 if key not in ["prompt", "key_name", "default"]:
                     if key in os.environ:
-                        llm_params[key] = value
+                        llm_params[key] = os.environ[key]

+    # Remove None values
+    llm_params = {k: v for k, v in llm_params.items() if v is not None}
+
     # Try creating the LLM
     try:
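The None-filtering step added above can be shown in isolation: parameters that were never populated from the environment simply do not reach the LLM constructor.

from typing import Any, Dict

llm_params: Dict[str, Any] = {
    "model": "gpt-4o-mini",
    "temperature": None,
    "base_url": "http://localhost:8000/v1",
    "seed": None,
}

# Same comprehension as in the hunk above: drop keys whose value is still None.
llm_params = {k: v for k, v in llm_params.items() if v is not None}
print(llm_params)  # {'model': 'gpt-4o-mini', 'base_url': 'http://localhost:8000/v1'}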
@@ -150,7 +197,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
         print(f"LLM created with model='{model_name}'")
         return new_llm
     except Exception as e:
-        print(f"Error instantiating LLM from environment/fallback: {e}")
+        print(f"Error instantiating LLM from environment/fallback: {type(e).__name__}: {e}")
         return None

@@ -1,4 +1,5 @@
 import warnings
+from typing import Any, Dict, Optional

 from litellm.integrations.custom_logger import CustomLogger
 from litellm.types.utils import Usage
@@ -7,10 +8,16 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess


 class TokenCalcHandler(CustomLogger):
-    def __init__(self, token_cost_process: TokenProcess):
+    def __init__(self, token_cost_process: Optional[TokenProcess]):
         self.token_cost_process = token_cost_process

-    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+    def log_success_event(
+        self,
+        kwargs: Dict[str, Any],
+        response_obj: Dict[str, Any],
+        start_time: float,
+        end_time: float,
+    ) -> None:
         if self.token_cost_process is None:
             return

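A hedged sketch of why token_cost_process became Optional: the handler may now be constructed before a TokenProcess exists, and log_success_event returns early in that case. The TokenCalcHandler import path and the no-argument TokenProcess() construction are assumptions.

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.token_counter_callback import TokenCalcHandler  # path assumed

handler_without_process = TokenCalcHandler(None)  # allowed by Optional[TokenProcess]
handler_without_process.log_success_event({}, {}, 0.0, 0.0)  # no-op: returns early

handler_with_process = TokenCalcHandler(TokenProcess())  # normal token accounting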