Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-29 18:18:13 +00:00
Compare commits
3 Commits
39bdc7e4d4 ... bfb578d506

| Author | SHA1 | Date |
|---|---|---|
|  | bfb578d506 |  |
|  | 5d3c34b3ea |  |
|  | 8ec2eb7d72 |  |
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import shutil
import subprocess
@@ -21,11 +23,11 @@ from crewai.tools.base_tool import Tool
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.logger import Logger
from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.token_process import TokenProcess
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler

agentops = None

@@ -132,7 +134,6 @@ class Agent(BaseAgent):
"verbose": verbose,
"max_rpm": max_rpm,
"tools": processed_tools,
"llm": processed_llm,
"max_iter": max_iter if max_iter is not None else 25,
"function_calling_llm": function_calling_llm,
"step_callback": step_callback,
@@ -148,11 +149,14 @@ class Agent(BaseAgent):
self._original_goal = goal
self._original_backstory = backstory

# Set LLM after base initialization to ensure proper model handling
self.llm = processed_llm

# Initialize private attributes
self._logger = Logger(verbose=self.verbose)
if self.max_rpm:
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
self._token_process = TokenProcess() # type: ignore # Known type mismatch between utilities and agent_builder
self._token_process = TokenProcess()

_times_executed: int = PrivateAttr(default=0)
max_execution_time: Optional[int] = Field(
@@ -530,6 +534,32 @@ class Agent(BaseAgent):
self.response_template.split("{{ .Response }}")[1].strip()
)

# Ensure LLM is initialized with proper error handling
try:
if not self.llm:
self.llm = LLM(model="gpt-4")
if hasattr(self, '_logger'):
self._logger.debug("Initialized default LLM with gpt-4 model")
except Exception as e:
if hasattr(self, '_logger'):
self._logger.error(f"Failed to initialize LLM: {str(e)}")
raise

# Create token callback with proper error handling
try:
token_callback = None
if hasattr(self, '_token_process'):
token_callback = TokenCalcHandler(self._token_process)
except Exception as e:
if hasattr(self, '_logger'):
self._logger.warning(f"Failed to create token callback: {str(e)}")
token_callback = None

# Initialize callbacks list
executor_callbacks = []
if token_callback:
executor_callbacks.append(token_callback)

self.agent_executor = CrewAgentExecutor(
llm=self.llm,
task=task,
@@ -547,9 +577,9 @@ class Agent(BaseAgent):
function_calling_llm=self.function_calling_llm,
respect_context_window=self.respect_context_window,
request_within_rpm_limit=(
self._rpm_controller.check_or_wait if self._rpm_controller else None
self._rpm_controller.check_or_wait if (hasattr(self, '_rpm_controller') and self._rpm_controller is not None) else None
),
callbacks=[TokenCalcHandler(self._token_process)],
callbacks=executor_callbacks,
)

def get_delegation_tools(self, agents: List[BaseAgent]):
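The executor-wiring hunk above only attaches helpers whose prerequisites actually exist: the token callback is built only when a token process is present, and the RPM gate only when a controller was created. A minimal, self-contained sketch of that defensive pattern follows; `FakeTokenHandler` and `build_executor_kwargs` are illustrative names, not crewAI APIs.

```python
from typing import Any


class FakeTokenHandler:
    """Stand-in for TokenCalcHandler; illustrative only."""

    def __init__(self, token_process: Any) -> None:
        self.token_process = token_process


def build_executor_kwargs(agent: Any) -> dict[str, Any]:
    # Only build the token callback if the agent actually has a token process.
    callbacks: list[Any] = []
    if getattr(agent, "_token_process", None) is not None:
        callbacks.append(FakeTokenHandler(agent._token_process))

    # Only expose the RPM gate if an RPM controller was created (max_rpm was set).
    rpm_controller = getattr(agent, "_rpm_controller", None)
    rpm_gate = rpm_controller.check_or_wait if rpm_controller is not None else None

    return {"callbacks": callbacks, "request_within_rpm_limit": rpm_gate}
```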
@@ -5,6 +5,7 @@ from pathlib import Path
import click
import requests
from typing import Any

from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS

@@ -192,7 +193,7 @@ def download_data(response):
data_chunks = []
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as progress_bar:
) as progress_bar: # type: Any
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)
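For the progress-bar hunk above, here is a hedged, self-contained version of the download loop. The `progress_bar.update` call and the return value are additions for illustration and may differ from the real `crewai.cli` helper.

```python
import click
import requests


def download_data(response: requests.Response, block_size: int = 1024) -> bytes:
    """Stream a response body while showing a click progress bar (sketch only)."""
    total_size = int(response.headers.get("content-length", 0))
    data_chunks = []
    with click.progressbar(
        length=total_size, label="Downloading", show_pos=True
    ) as progress_bar:
        for chunk in response.iter_content(block_size):
            if chunk:
                data_chunks.append(chunk)
                progress_bar.update(len(chunk))
    return b"".join(data_chunks)
```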
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union

from pydantic import BaseModel, Field

@@ -23,14 +23,25 @@ class CrewOutput(BaseModel):
)
token_usage: UsageMetrics = Field(description="Processed token summary", default={})

@property
def json(self) -> Optional[str]:
def json(
self,
*,
include: Union[set[str], None] = None,
exclude: Union[set[str], None] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
models_as_dict: bool = True,
**dumps_kwargs: Any,
) -> str:
if self.tasks_output[-1].output_format != OutputFormat.JSON:
raise ValueError(
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
)

return json.dumps(self.json_dict)
return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)

def to_dict(self) -> Dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""
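With this hunk, `json` on `CrewOutput` becomes a regular method rather than a property, so callers invoke `crew_output.json()` and can pass `json.dumps` keyword arguments through. A small stand-alone sketch of what the new body does (a hypothetical helper, not crewAI code):

```python
import json
from typing import Any, Callable, Optional


def dump_json_output(json_dict: dict[str, Any],
                     encoder: Optional[Callable[[Any], Any]] = None,
                     **dumps_kwargs: Any) -> str:
    """Mirror of what the new CrewOutput.json() / TaskOutput.json() bodies do."""
    return json.dumps(json_dict, default=encoder, **dumps_kwargs)


# Extra keyword arguments are forwarded straight to json.dumps:
print(dump_json_output({"answer": 42}, indent=2, sort_keys=True))
```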
@@ -106,7 +106,12 @@ class FlowPlot:
# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
add_nodes_to_network(
net,
flow=self.flow,
pos=node_positions,
node_styles=self.node_styles
)
except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
@@ -6,6 +6,8 @@ import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field

with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import litellm
@@ -93,10 +95,33 @@ def suppress_warnings():
sys.stderr = old_stderr


class LLM:
class LLM(BaseModel):
model: str = "gpt-4" # Set default model
timeout: Optional[Union[float, int]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
max_completion_tokens: Optional[int] = None
max_tokens: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[Dict[int, float]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
api_key: Optional[str] = None
callbacks: Optional[List[Any]] = None
context_window_size: Optional[int] = None
kwargs: Dict[str, Any] = Field(default_factory=dict)
logger: Optional[logging.Logger] = Field(default_factory=lambda: logging.getLogger(__name__))

def __init__(
self,
model: Union[str, 'LLM'],
model: Optional[Union[str, 'LLM']] = "gpt-4",
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -114,12 +139,103 @@ class LLM:
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
**kwargs,
):
callbacks: Optional[List[Any]] = None,
context_window_size: Optional[int] = None,
**kwargs: Any,
) -> None:
# Initialize with default values
init_dict = {
"model": model if isinstance(model, str) else "gpt-4",
"timeout": timeout,
"temperature": temperature,
"top_p": top_p,
"n": n,
"stop": stop,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"response_format": response_format,
"seed": seed,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"base_url": base_url,
"api_version": api_version,
"api_key": api_key,
"callbacks": callbacks,
"context_window_size": context_window_size,
"kwargs": kwargs,
}
super().__init__(**init_dict)

# Initialize model with default value
self.model = "gpt-4" # Default fallback

# Extract and validate model name
if isinstance(model, LLM):
# Extract and validate model name from LLM instance
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
else:
# Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string

# If model is an LLM instance, copy its configuration
if isinstance(model, LLM):
self.model = model.model
# Extract and validate model name first
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")

# Copy other configuration
self.timeout = model.timeout
self.temperature = model.temperature
self.top_p = model.top_p
@@ -140,8 +256,44 @@ class LLM:
self.callbacks = model.callbacks
self.context_window_size = model.context_window_size
self.kwargs = model.kwargs

# Final validation of model name
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after LLM copy, using default: gpt-4")
else:
self.model = model
# Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string

# Final validation
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after extraction, using default: gpt-4")

self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
@@ -163,69 +315,155 @@ class LLM:
self.context_window_size = 0
self.kwargs = kwargs

# Ensure model is a string after initialization
if not isinstance(self.model, str):
self.model = "gpt-4"
self.logger.warning(f"Model is still not a string after initialization, using default: {self.model}")

litellm.drop_params = True

self.set_callbacks(callbacks)
self.set_env_callbacks()
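A hedged sketch of how the reworked constructor above is meant to be used: a plain string model, configuration copied from another `LLM` instance, and the `"gpt-4"` fallback when no usable model name is found. The import path and the exact behavior are assumptions based on this diff, and `LLM(model=None)` is shown only to illustrate the fallback branch.

```python
from crewai.llm import LLM  # module path assumed from this diff

llm = LLM(model="gpt-4o", temperature=0.2, max_tokens=512)  # normal string model
copied = LLM(model=llm)     # configuration copied from an existing LLM instance
fallback = LLM(model=None)  # falls back to "gpt-4" and logs a warning

print(llm.model, copied.model, fallback.model)
```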
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
def call(
self,
messages: List[Dict[str, str]],
callbacks: Optional[List[Any]] = None
) -> str:
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)

try:
# Ensure model is a string and set default
model_name = "gpt-4" # Default model

# Extract model name from self.model
current = self.model
while current is not None:
if isinstance(current, str):
model_name = current
break
elif isinstance(current, LLM):
current = current.model
elif hasattr(current, "model"):
current = getattr(current, "model")
else:
break
# Store original model to restore later
original_model = self.model

try:
# Ensure model is a string before making the call
if not isinstance(self.model, str):
if self.logger:
self.logger.warning(f"Model is not a string in call method: {type(self.model)}. Attempting to convert...")
if isinstance(self.model, LLM):
self.model = self.model.model if isinstance(self.model.model, str) else "gpt-4"
elif hasattr(self.model, 'model_name'):
self.model = str(self.model.model_name)
elif hasattr(self.model, 'model'):
if isinstance(self.model.model, str):
self.model = str(self.model.model)
elif hasattr(self.model.model, 'model_name'):
self.model = str(self.model.model.model_name)
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name from nested model object, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")

if self.logger:
self.logger.debug(f"Using model: {self.model} (type: {type(self.model)}) for LiteLLM call")

# Create base params with validated model name
# Extract model name string
model_name = None
if isinstance(self.model, str):
model_name = self.model
elif hasattr(self.model, 'model_name'):
model_name = str(self.model.model_name)
elif hasattr(self.model, 'model'):
if isinstance(self.model.model, str):
model_name = str(self.model.model)
elif hasattr(self.model.model, 'model_name'):
model_name = str(self.model.model.model_name)

if not model_name:
model_name = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")

# Set parameters for litellm
# Build base params dict with required fields
params = {
"model": model_name,
"custom_llm_provider": "openai",
"messages": messages,
"stream": False # Always set stream to False
"stream": False,
"api_key": self.api_key or os.getenv("OPENAI_API_KEY"),
"api_base": self.base_url,
"api_version": self.api_version,
}

# Add API configuration
if self.logger:
self.logger.debug(f"Using model parameters: {params}")

# Add API configuration if available
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
if api_key:
params["api_key"] = api_key

# Define optional parameters
optional_params = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
}

# Try to get supported parameters for the model
try:
supported_params = get_supported_openai_params(self.model)
optional_params = {}

if supported_params:
param_mapping = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}

# Only add parameters that are supported and not None
optional_params = {
k: v for k, v in param_mapping.items()
if k in supported_params and v is not None
}
if "logprobs" in supported_params and self.logprobs is not None:
optional_params["logprobs"] = self.logprobs
if "top_logprobs" in supported_params and self.top_logprobs is not None:
optional_params["top_logprobs"] = self.top_logprobs
except Exception as e:
if self.logger:
self.logger.error(f"Failed to get supported params for model {self.model}: {str(e)}")
# If we can't get supported params, just add non-None parameters
param_mapping = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
optional_params = {k: v for k, v in param_mapping.items() if v is not None}

# Update params with optional parameters
params.update(optional_params)

# Add API endpoint configuration if available
if self.base_url:
optional_params["api_base"] = self.base_url
params["api_base"] = self.base_url
if self.api_version:
optional_params["api_version"] = self.api_version
params["api_version"] = self.api_version

# Final validation of model parameter
if not isinstance(params["model"], str):
if self.logger:
self.logger.error(f"Model is still not a string after all conversions: {type(params['model'])}")
params["model"] = "gpt-4"

# Update params with non-None optional parameters
params.update({k: v for k, v in optional_params.items() if v is not None})
@@ -238,21 +476,38 @@ class LLM:
params = {k: v for k, v in params.items() if v is not None}

response = litellm.completion(**params)
return response["choices"][0]["message"]["content"]
content = response["choices"][0]["message"]["content"]

# Extract usage metrics
usage = response.get("usage", {})
if callbacks:
for callback in callbacks:
if hasattr(callback, "update_token_usage"):
callback.update_token_usage(usage)

return content
except Exception as e:
if not LLMContextLengthExceededException(
str(e)
)._is_context_limit_error(str(e)):
logging.error(f"LiteLLM call failed: {str(e)}")

raise # Re-raise the exception after logging
finally:
# Always restore the original model object
self.model = original_model
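Usage sketch for the reworked `call()` above: it forwards the validated parameters to `litellm.completion`, notifies any callbacks that expose `update_token_usage`, and restores the original model object in the `finally` block. The import path, credentials, and model name below are placeholders, and the snippet performs a real network request.

```python
import os

from crewai.llm import LLM  # module path assumed from this diff

os.environ.setdefault("OPENAI_API_KEY", "sk-your-key-here")  # placeholder credential

llm = LLM(model="gpt-4o-mini", temperature=0.0)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in one word."},
]

reply = llm.call(messages)  # routes through litellm.completion
print(reply)                # the assistant message content as a string
```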
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.

Returns:
bool: True if the model supports function calling, False otherwise
"""
try:
params = get_supported_openai_params(model=self.model)
return "response_format" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False

def supports_stop_words(self) -> bool:
@@ -264,33 +519,47 @@ class LLM:
params = get_supported_openai_params(model=self.model)
return "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False

def get_context_window_size(self) -> int:
"""Get the context window size for the current model.

Returns:
int: The context window size in tokens
"""
# Only using 75% of the context window size to avoid cutting the message in the middle
if self.context_window_size != 0:
return self.context_window_size
if self.context_window_size is not None and self.context_window_size != 0:
return int(self.context_window_size)

self.context_window_size = int(
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
)
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
window_size = DEFAULT_CONTEXT_WINDOW_SIZE
if isinstance(self.model, str):
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
window_size = value
break

self.context_window_size = int(window_size * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
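The rewritten `get_context_window_size` guards the `startswith` lookup behind an `isinstance` check and computes the usable window once. A stand-alone sketch of that fallback logic, with illustrative table values rather than crewAI's real constants:

```python
# Illustrative constants; the real values live in crewAI's LLM module.
DEFAULT_CONTEXT_WINDOW_SIZE = 8192
CONTEXT_WINDOW_USAGE_RATIO = 0.75  # keep ~75% to avoid truncating mid-message
LLM_CONTEXT_WINDOW_SIZES = {"gpt-4o": 128000, "gpt-4": 8192}


def usable_context_window(model: str) -> int:
    window_size = DEFAULT_CONTEXT_WINDOW_SIZE
    if isinstance(model, str):
        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
            if model.startswith(key):
                window_size = value
                break
    return int(window_size * CONTEXT_WINDOW_USAGE_RATIO)


print(usable_context_window("gpt-4o-mini"))    # 96000 with the sample table above
print(usable_context_window("unknown-model"))  # 6144, the default fallback
```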
def set_callbacks(self, callbacks: List[Any]):
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
def set_callbacks(self, callbacks: Optional[List[Any]] = None) -> None:
"""Set callbacks for the LLM.

Args:
callbacks: Optional list of callback functions. If None, no callbacks will be set.
"""
if callbacks is not None:
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)

for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)

litellm.callbacks = callbacks
litellm.callbacks = callbacks

def set_env_callbacks(self):
"""
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union

from pydantic import BaseModel, Field, model_validator

@@ -34,8 +34,19 @@ class TaskOutput(BaseModel):
self.summary = f"{excerpt}..."
return self

@property
def json(self) -> Optional[str]:
def json(
self,
*,
include: Union[set[str], None] = None,
exclude: Union[set[str], None] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
models_as_dict: bool = True,
**dumps_kwargs: Any,
) -> str:
if self.output_format != OutputFormat.JSON:
raise ValueError(
"""
@@ -45,7 +56,7 @@ class TaskOutput(BaseModel):
"""
)

return json.dumps(self.json_dict)
return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)

def to_dict(self) -> Dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""
@@ -142,7 +142,12 @@ class CrewStructuredTool:
# Create model
schema_name = f"{name.title()}Schema"
return create_model(schema_name, **fields)
return create_model(
schema_name,
__base__=BaseModel,
__config__=None,
**{k: v for k, v in fields.items()}
)

def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema."""
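The hunk above pins an explicit `__base__` when building the tool's args schema. A small hedged sketch of `pydantic.create_model` used that way (pydantic v2 assumed), with hypothetical field definitions in the `(type, default)` form it expects:

```python
from pydantic import BaseModel, create_model

# Hypothetical fields for a search tool's schema; (type, default) tuples.
fields = {
    "query": (str, ...),       # required
    "max_results": (int, 10),  # optional with a default
}

SearchToolSchema = create_model("SearchToolSchema", __base__=BaseModel, **fields)

print(SearchToolSchema(query="crewai").model_dump())
# {'query': 'crewai', 'max_results': 10}
```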
@@ -170,7 +175,7 @@ class CrewStructuredTool:
f"not found in args_schema"
)

def _parse_args(self, raw_args: Union[str, dict]) -> dict:
def _parse_args(self, raw_args: Union[str, dict[str, Any]]) -> dict[str, Any]:
"""Parse and validate the input arguments against the schema.

Args:
@@ -178,6 +183,9 @@ class CrewStructuredTool:

Returns:
The validated arguments as a dictionary

Raises:
ValueError: If the arguments cannot be parsed or fail validation
"""
if isinstance(raw_args, str):
try:
@@ -195,8 +203,8 @@ class CrewStructuredTool:

async def ainvoke(
self,
input: Union[str, dict],
config: Optional[dict] = None,
input: Union[str, dict[str, Any]],
config: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Asynchronously invoke the tool.
@@ -229,7 +237,10 @@ class CrewStructuredTool:
return self.invoke(input_dict)

def invoke(
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
self,
input: Union[str, dict[str, Any]],
config: Optional[dict[str, Any]] = None,
**kwargs: Any
) -> Any:
"""Main method for tool execution."""
parsed_args = self._parse_args(input)
@@ -10,8 +10,24 @@ class Logger(BaseModel):
_printer: Printer = PrivateAttr(default_factory=Printer)

def log(self, level, message, color="bold_yellow"):
if self.verbose:
if self.verbose or level.upper() in ["WARNING", "ERROR"]:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self._printer.print(
f"\n[{timestamp}][{level.upper()}]: {message}", color=color
)

def debug(self, message: str) -> None:
"""Log a debug message if verbose is enabled."""
self.log("debug", message, color="bold_blue")

def info(self, message: str) -> None:
"""Log an info message if verbose is enabled."""
self.log("info", message, color="bold_green")

def warning(self, message: str) -> None:
"""Log a warning message."""
self.log("warning", message, color="bold_yellow")

def error(self, message: str) -> None:
"""Log an error message."""
self.log("error", message, color="bold_red")
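After this change, `warning` and `error` messages print even when the logger was built with `verbose=False`; only `debug` and `info` stay gated on verbosity. A short usage sketch, using the import path that appears earlier in this diff:

```python
from crewai.utilities.logger import Logger  # import path shown earlier in this diff

logger = Logger(verbose=False)

logger.debug("hidden unless verbose=True")
logger.info("hidden unless verbose=True")
logger.warning("printed even though verbose is False")
logger.error("printed even though verbose is False")
```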
@@ -1,44 +1,63 @@
"""Token processing utility for tracking and managing token usage."""
from __future__ import annotations

from typing import Any, Dict, List, Optional, Union

from crewai.types.usage_metrics import UsageMetrics


class TokenProcess:
"""Handles token processing and tracking for agents."""

def __init__(self):
"""Initialize the token processor."""
self._token_count = 0
self._last_tokens = 0
self._total_tokens = 0
self._prompt_tokens = 0
self._completion_tokens = 0
self._cached_prompt_tokens = 0
self._successful_requests = 0

def update_token_count(self, count: int) -> None:
"""Update the token count.
def sum_prompt_tokens(self, count: int) -> None:
"""Add to prompt token count.

Args:
count (int): Number of tokens to add to the count
count (int): Number of prompt tokens to add
"""
self._token_count += count
self._last_tokens = count
self._prompt_tokens += count
self._total_tokens += count

def get_token_count(self) -> int:
"""Get the total token count.
def sum_completion_tokens(self, count: int) -> None:
"""Add to completion token count.

Returns:
int: Total number of tokens processed
Args:
count (int): Number of completion tokens to add
"""
return self._token_count
self._completion_tokens += count
self._total_tokens += count

def get_last_tokens(self) -> int:
"""Get the number of tokens from the last update.
def sum_cached_prompt_tokens(self, count: int) -> None:
"""Add to cached prompt token count.

Returns:
int: Number of tokens from last update
Args:
count (int): Number of cached prompt tokens to add
"""
return self._last_tokens
self._cached_prompt_tokens += count

def sum_successful_requests(self, count: int) -> None:
"""Add to successful requests count.

Args:
count (int): Number of successful requests to add
"""
self._successful_requests += count

def reset(self) -> None:
"""Reset the token counts to zero."""
self._token_count = 0
self._last_tokens = 0
"""Reset all token counts to zero."""
self._total_tokens = 0
self._prompt_tokens = 0
self._completion_tokens = 0
self._cached_prompt_tokens = 0
self._successful_requests = 0

def get_summary(self) -> UsageMetrics:
"""Get a summary of token usage.
@@ -47,9 +66,9 @@ class TokenProcess:
UsageMetrics: Object containing token usage metrics
"""
return UsageMetrics(
total_tokens=self._token_count,
prompt_tokens=0, # These will be set by the LLM handler
cached_prompt_tokens=0,
completion_tokens=self._last_tokens,
successful_requests=1 if self._token_count > 0 else 0
total_tokens=self._total_tokens,
prompt_tokens=self._prompt_tokens,
cached_prompt_tokens=self._cached_prompt_tokens,
completion_tokens=self._completion_tokens,
successful_requests=self._successful_requests
)
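A usage sketch against the reworked `TokenProcess` API above. The import path is an assumption (the diff references both `crewai.utilities.token_process` and the `agent_builder` copy), and the numbers are arbitrary:

```python
from crewai.utilities.token_process import TokenProcess  # path assumed from the imports above

tp = TokenProcess()
tp.sum_prompt_tokens(120)
tp.sum_completion_tokens(45)
tp.sum_cached_prompt_tokens(30)
tp.sum_successful_requests(1)

summary = tp.get_summary()           # UsageMetrics instance
print(summary.total_tokens)          # 165 = prompt + completion
print(summary.prompt_tokens)         # 120
print(summary.cached_prompt_tokens)  # 30
print(summary.successful_requests)   # 1

tp.reset()                           # zero every counter again
```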