Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-11 00:58:30 +00:00)

Commit: Fix token tracking in LangChainAgentAdapter and refactor token_process attribute to be public
@@ -82,13 +82,18 @@ class BaseAgent(ABC, BaseModel):
     """

     __hash__ = object.__hash__  # type: ignore
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
+
     _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
     _rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
     _request_within_rpm_limit: Any = PrivateAttr(default=None)
     _original_role: Optional[str] = PrivateAttr(default=None)
     _original_goal: Optional[str] = PrivateAttr(default=None)
     _original_backstory: Optional[str] = PrivateAttr(default=None)
-    _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
+    token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
     id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
     formatting_errors: int = Field(
         default=0, description="Number of formatting errors."
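
Note: with `token_process` now a public Pydantic field (kept out of serialization via `exclude=True`), callers can read an agent's usage directly instead of reaching into a private attribute. A minimal sketch, assuming an agent that has already executed work through a crew (the construction arguments are illustrative, not part of this diff):

    from crewai import Agent

    agent = Agent(role="Researcher", goal="...", backstory="...")
    # ... run tasks via a Crew ...
    summary = agent.token_process.get_summary()  # TokenProcess API used throughout this diff
    print(summary.prompt_tokens, summary.completion_tokens, summary.total_tokens)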
@@ -198,8 +203,6 @@ class BaseAgent(ABC, BaseModel):
             self._rpm_controller = RPMController(
                 max_rpm=self.max_rpm, logger=self._logger
             )
-        if not self._token_process:
-            self._token_process = TokenProcess()

         return self

@@ -219,8 +222,7 @@ class BaseAgent(ABC, BaseModel):
             self._rpm_controller = RPMController(
                 max_rpm=self.max_rpm, logger=self._logger
             )
-        if not self._token_process:
-            self._token_process = TokenProcess()
+
         return self

     @property
@@ -268,7 +270,7 @@ class BaseAgent(ABC, BaseModel):
             "_logger",
             "_rpm_controller",
             "_request_within_rpm_limit",
-            "_token_process",
+            "token_process",
             "agent_executor",
             "tools",
             "tools_handler",
@@ -1,18 +1,17 @@
 from typing import Any, List, Optional, Type, Union, cast

-from crewai.tools.base_tool import Tool
-
-try:
-    from langchain_core.tools import Tool as LangChainTool  # type: ignore
-except ImportError:
-    LangChainTool = None
-
 from pydantic import Field, field_validator

 from crewai.agents.agent_builder.base_agent import BaseAgent
+from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.task import Task
 from crewai.tools import BaseTool
+from crewai.tools.base_tool import Tool
 from crewai.utilities.converter import Converter, generate_model_description
+from crewai.utilities.token_counter_callback import (
+    LangChainTokenCounter,
+    LiteLLMTokenCounter,
+)


 class LangChainAgentAdapter(BaseAgent):
@@ -51,6 +50,8 @@ class LangChainAgentAdapter(BaseAgent):
     i18n: Any = None
     crew: Any = None
     knowledge: Any = None
+    token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
+    token_callback: Optional[Any] = None

     class Config:
         arbitrary_types_allowed = True
@@ -72,16 +73,35 @@ class LangChainAgentAdapter(BaseAgent):
     def _extract_text(self, message: Any) -> str:
         """
         Helper to extract plain text from a message object.
-        This checks if the message is a dict with a "content" key, or has a "content" attribute.
+        This checks if the message is a dict with a "content" key, or has a "content" attribute,
+        or if it's a tuple from LangGraph's message format.
         """
-        if isinstance(message, dict) and "content" in message:
-            return message["content"]
+        # Handle LangGraph message tuple format (role, content)
+        if isinstance(message, tuple) and len(message) == 2:
+            return str(message[1])
+
+        # Handle dictionary with content key
+        elif isinstance(message, dict):
+            if "content" in message:
+                return message["content"]
+            # Handle LangGraph message format with additional metadata
+            elif "messages" in message and message["messages"]:
+                last_message = message["messages"][-1]
+                if isinstance(last_message, tuple) and len(last_message) == 2:
+                    return str(last_message[1])
+                return self._extract_text(last_message)
+
+        # Handle object with content attribute
         elif hasattr(message, "content") and isinstance(
             getattr(message, "content"), str
         ):
             return getattr(message, "content")
+
+        # Handle string directly
         elif isinstance(message, str):
             return message
+
+        # Default fallback
         return str(message)

     def execute_task(
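
The reworked `_extract_text` normalizes the different shapes LangGraph and LangChain can hand back. A rough sketch of the inputs it is meant to cover, assuming `adapter` is a constructed LangChainAgentAdapter (the values are illustrative):

    adapter._extract_text(("assistant", "final answer"))            # LangGraph (role, content) tuple
    adapter._extract_text({"content": "plain dict"})                # dict with a content key
    adapter._extract_text({"messages": [("assistant", "nested")]})  # state dict with a messages list
    adapter._extract_text("already a string")                       # plain string passthrough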
@@ -161,19 +181,77 @@ class LangChainAgentAdapter(BaseAgent):
        else:
            task_prompt = self._use_trained_data(task_prompt=task_prompt)

+        # Initialize token tracking callback if needed
+        if hasattr(self, "token_process") and self.token_callback is None:
+            # Determine if we're using LangChain or LiteLLM based on the agent type
+            if hasattr(self.langchain_agent, "client") and hasattr(
+                self.langchain_agent.client, "callbacks"
+            ):
+                # This is likely a LiteLLM-based agent
+                self.token_callback = LiteLLMTokenCounter(self.token_process)
+
+                # Add our callback to the LLM directly
+                if isinstance(self.langchain_agent.client.callbacks, list):
+                    self.langchain_agent.client.callbacks.append(self.token_callback)
+                else:
+                    self.langchain_agent.client.callbacks = [self.token_callback]
+            else:
+                # This is likely a LangChain-based agent
+                self.token_callback = LangChainTokenCounter(self.token_process)
+
+                # Add callback to the LangChain model
+                if hasattr(self.langchain_agent, "callbacks"):
+                    if self.langchain_agent.callbacks is None:
+                        self.langchain_agent.callbacks = [self.token_callback]
+                    elif isinstance(self.langchain_agent.callbacks, list):
+                        self.langchain_agent.callbacks.append(self.token_callback)
+                # For direct LLM models
+                elif hasattr(self.langchain_agent, "llm") and hasattr(
+                    self.langchain_agent.llm, "callbacks"
+                ):
+                    if self.langchain_agent.llm.callbacks is None:
+                        self.langchain_agent.llm.callbacks = [self.token_callback]
+                    elif isinstance(self.langchain_agent.llm.callbacks, list):
+                        self.langchain_agent.llm.callbacks.append(self.token_callback)
+                # Direct LLM case
+                elif not hasattr(self.langchain_agent, "agent"):
+                    # This might be a direct LLM, not an agent
+                    if (
+                        not hasattr(self.langchain_agent, "callbacks")
+                        or self.langchain_agent.callbacks is None
+                    ):
+                        self.langchain_agent.callbacks = [self.token_callback]
+                    elif isinstance(self.langchain_agent.callbacks, list):
+                        self.langchain_agent.callbacks.append(self.token_callback)
+
        init_state = {"messages": [("user", task_prompt)]}
+
+        # Estimate input tokens for tracking
+        if hasattr(self, "token_process"):
+            # Rough estimate based on characters (better than word count)
+            estimated_prompt_tokens = len(task_prompt) // 4  # ~4 chars per token
+            self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
+
        state = self.agent_executor.invoke(init_state)
+
+        # Extract output from state based on its structure
        if "structured_response" in state:
            current_output = state["structured_response"]
        elif "messages" in state and state["messages"]:
            last_message = state["messages"][-1]
-            if isinstance(last_message, tuple):
-                current_output = last_message[1]
-            else:
-                current_output = self._extract_text(last_message)
+            current_output = self._extract_text(last_message)
+        elif "output" in state:
+            current_output = str(state["output"])
        else:
-            current_output = ""
+            # Fallback to extracting text from the entire state
+            current_output = self._extract_text(state)
+
+        # Estimate completion tokens for tracking if we don't have actual counts
+        if hasattr(self, "token_process"):
+            # Rough estimate based on characters
+            estimated_completion_tokens = len(current_output) // 4  # ~4 chars per token
+            self.token_process.sum_completion_tokens(estimated_completion_tokens)
+            self.token_process.sum_successful_requests(1)
+
        if task.human_input:
            current_output = self._handle_human_feedback(current_output)
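
When no provider-reported counts arrive through the callbacks, the adapter falls back to a rough character-based estimate of about 4 characters per token. A worked example of that heuristic, using the prompt text from the new test below:

    task_prompt = "Fetch the current weather for San Francisco."  # 44 characters
    estimated_prompt_tokens = len(task_prompt) // 4               # -> 11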
@@ -203,20 +281,40 @@ class LangChainAgentAdapter(BaseAgent):
                f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\n"
                f"Updated answer:"
            )

+            # Estimate input tokens for tracking
+            if hasattr(self, "token_process"):
+                # Rough estimate based on characters
+                estimated_prompt_tokens = len(new_prompt) // 4  # ~4 chars per token
+                self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
+
            try:
                new_state = self.agent_executor.invoke(
                    {"messages": [("user", new_prompt)]}
                )
+                # Extract output from state based on its structure
                if "structured_response" in new_state:
                    new_output = new_state["structured_response"]
                elif "messages" in new_state and new_state["messages"]:
                    last_message = new_state["messages"][-1]
-                    if isinstance(last_message, tuple):
-                        new_output = last_message[1]
-                    else:
-                        new_output = self._extract_text(last_message)
+                    new_output = self._extract_text(last_message)
+                elif "output" in new_state:
+                    new_output = str(new_state["output"])
                else:
-                    new_output = ""
+                    # Fallback to extracting text from the entire state
+                    new_output = self._extract_text(new_state)
+
+                # Estimate completion tokens for tracking
+                if hasattr(self, "token_process"):
+                    # Rough estimate based on characters
+                    estimated_completion_tokens = (
+                        len(new_output) // 4
+                    )  # ~4 chars per token
+                    self.token_process.sum_completion_tokens(
+                        estimated_completion_tokens
+                    )
+                    self.token_process.sum_successful_requests(1)
+
                current_output = new_output
            except Exception as e:
                print("Error during re-invocation with feedback:", e)
@@ -310,6 +408,52 @@ class LangChainAgentAdapter(BaseAgent):
        agent_role = getattr(self, "role", "agent")
        sanitized_role = re.sub(r"\s+", "_", agent_role)

+        # Initialize token tracking callback if needed
+        if hasattr(self, "token_process") and self.token_callback is None:
+            # Determine if we're using LangChain or LiteLLM based on the agent type
+            if hasattr(self.langchain_agent, "client") and hasattr(
+                self.langchain_agent.client, "callbacks"
+            ):
+                # This is likely a LiteLLM-based agent
+                self.token_callback = LiteLLMTokenCounter(self.token_process)
+
+                # Add our callback to the LLM directly
+                if isinstance(self.langchain_agent.client.callbacks, list):
+                    if self.token_callback not in self.langchain_agent.client.callbacks:
+                        self.langchain_agent.client.callbacks.append(
+                            self.token_callback
+                        )
+                else:
+                    self.langchain_agent.client.callbacks = [self.token_callback]
+            else:
+                # This is likely a LangChain-based agent
+                self.token_callback = LangChainTokenCounter(self.token_process)
+
+                # Add callback to the LangChain model
+                if hasattr(self.langchain_agent, "callbacks"):
+                    if self.langchain_agent.callbacks is None:
+                        self.langchain_agent.callbacks = [self.token_callback]
+                    elif isinstance(self.langchain_agent.callbacks, list):
+                        self.langchain_agent.callbacks.append(self.token_callback)
+                # For direct LLM models
+                elif hasattr(self.langchain_agent, "llm") and hasattr(
+                    self.langchain_agent.llm, "callbacks"
+                ):
+                    if self.langchain_agent.llm.callbacks is None:
+                        self.langchain_agent.llm.callbacks = [self.token_callback]
+                    elif isinstance(self.langchain_agent.llm.callbacks, list):
+                        self.langchain_agent.llm.callbacks.append(self.token_callback)
+                # Direct LLM case
+                elif not hasattr(self.langchain_agent, "agent"):
+                    # This might be a direct LLM, not an agent
+                    if (
+                        not hasattr(self.langchain_agent, "callbacks")
+                        or self.langchain_agent.callbacks is None
+                    ):
+                        self.langchain_agent.callbacks = [self.token_callback]
+                    elif isinstance(self.langchain_agent.callbacks, list):
+                        self.langchain_agent.callbacks.append(self.token_callback)
+
        self.agent_executor = create_react_agent(
            model=self.langchain_agent,
            tools=used_tools,
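
Unlike the execute_task path, this executor-creation path also guards against registering the same counter twice when the executor is rebuilt. The guard reduces to a simple membership check; a sketch with hypothetical `llm` and `token_callback` names:

    callbacks = getattr(llm, "callbacks", None) or []
    if token_callback not in callbacks:
        callbacks.append(token_callback)
    llm.callbacks = callbacks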
@@ -641,7 +641,7 @@ class Crew(BaseModel):
        for after_callback in self.after_kickoff_callbacks:
            result = after_callback(result)

-        metrics += [agent._token_process.get_summary() for agent in self.agents]
+        metrics += [agent.token_process.get_summary() for agent in self.agents]

        self.usage_metrics = UsageMetrics()
        for metric in metrics:
@@ -1195,12 +1195,15 @@ class Crew(BaseModel):
        """Calculates and returns the usage metrics."""
        total_usage_metrics = UsageMetrics()
        for agent in self.agents:
-            if hasattr(agent, "_token_process"):
-                token_sum = agent._token_process.get_summary()
-                total_usage_metrics.add_usage_metrics(token_sum)
-        if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
-            token_sum = self.manager_agent._token_process.get_summary()
+            # Directly access token_process since it's now a field in BaseAgent
+            token_sum = agent.token_process.get_summary()
            total_usage_metrics.add_usage_metrics(token_sum)
+
+        if self.manager_agent:
+            # Directly access token_process since it's now a field in BaseAgent
+            token_sum = self.manager_agent.token_process.get_summary()
+            total_usage_metrics.add_usage_metrics(token_sum)
+
        self.usage_metrics = total_usage_metrics
        return total_usage_metrics

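
Since every BaseAgent now carries a `token_process` field, the crew-level aggregation can read it unconditionally. A minimal sketch of inspecting the aggregated numbers after a run, assuming a constructed `crew` (the attribute names come from this diff and the new test):

    result = crew.kickoff()
    print(crew.usage_metrics)   # UsageMetrics summed across agents (and the manager agent, if any)
    print(result.token_usage)   # the summary asserted in tests/utilities/test_token_tracking.py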
@@ -1,15 +1,52 @@
 import warnings
-from typing import Any, Dict, Optional
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union

+from langchain_core.callbacks.base import BaseCallbackHandler
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.types.utils import Usage

 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess


-class TokenCalcHandler(CustomLogger):
-    def __init__(self, token_cost_process: Optional[TokenProcess]):
-        self.token_cost_process = token_cost_process
+class AbstractTokenCounter(ABC):
+    """
+    Abstract base class for token counting callbacks.
+    Implementations should track token usage from different LLM providers.
+    """
+
+    def __init__(self, token_process: Optional[TokenProcess] = None):
+        """Initialize with a TokenProcess instance to track tokens."""
+        self.token_process = token_process
+
+    @abstractmethod
+    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
+        """Update token usage counts in the token process."""
+        pass
+
+
+class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter):
+    """
+    Token counter implementation for LiteLLM.
+    Uses LiteLLM's CustomLogger interface to track token usage.
+    """
+
+    def __init__(self, token_process: Optional[TokenProcess] = None):
+        AbstractTokenCounter.__init__(self, token_process)
+        CustomLogger.__init__(self)
+
+    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
+        """Update token usage counts in the token process."""
+        if self.token_process is None:
+            return
+
+        if prompt_tokens > 0:
+            self.token_process.sum_prompt_tokens(prompt_tokens)
+
+        if completion_tokens > 0:
+            self.token_process.sum_completion_tokens(completion_tokens)
+
+        self.token_process.sum_successful_requests(1)

     def log_success_event(
         self,
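
`AbstractTokenCounter` gives both concrete counters a single `update_token_usage` contract, so a counter for another provider only has to map that provider's usage payload onto the same calls. A hypothetical third implementation, mirroring the bodies above:

    class MyProviderTokenCounter(AbstractTokenCounter):
        def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
            if self.token_process is None:
                return
            if prompt_tokens > 0:
                self.token_process.sum_prompt_tokens(prompt_tokens)
            if completion_tokens > 0:
                self.token_process.sum_completion_tokens(completion_tokens)
            self.token_process.sum_successful_requests(1)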
@@ -18,7 +55,11 @@ class TokenCalcHandler(CustomLogger):
         start_time: float,
         end_time: float,
     ) -> None:
-        if self.token_cost_process is None:
+        """
+        Process successful LLM call and extract token usage information.
+        This method is called by LiteLLM after a successful completion.
+        """
+        if self.token_process is None:
            return

        with warnings.catch_warnings():
@@ -26,18 +67,154 @@ class TokenCalcHandler(CustomLogger):
            if isinstance(response_obj, dict) and "usage" in response_obj:
                usage: Usage = response_obj["usage"]
                if usage:
-                    self.token_cost_process.sum_successful_requests(1)
+                    prompt_tokens = 0
+                    completion_tokens = 0
+
                    if hasattr(usage, "prompt_tokens"):
-                        self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
+                        prompt_tokens = usage.prompt_tokens
+                    elif isinstance(usage, dict) and "prompt_tokens" in usage:
+                        prompt_tokens = usage["prompt_tokens"]
+
                    if hasattr(usage, "completion_tokens"):
-                        self.token_cost_process.sum_completion_tokens(
-                            usage.completion_tokens
-                        )
+                        completion_tokens = usage.completion_tokens
+                    elif isinstance(usage, dict) and "completion_tokens" in usage:
+                        completion_tokens = usage["completion_tokens"]
+
+                    self.update_token_usage(prompt_tokens, completion_tokens)
+
+                    # Handle cached tokens if available
                    if (
                        hasattr(usage, "prompt_tokens_details")
                        and usage.prompt_tokens_details
                        and usage.prompt_tokens_details.cached_tokens
                    ):
-                        self.token_cost_process.sum_cached_prompt_tokens(
+                        self.token_process.sum_cached_prompt_tokens(
                            usage.prompt_tokens_details.cached_tokens
                        )
+
+
+class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
+    """
+    Token counter implementation for LangChain.
+    Implements the necessary callback methods to track token usage from LangChain responses.
+    """
+
+    def __init__(self, token_process: Optional[TokenProcess] = None):
+        BaseCallbackHandler.__init__(self)
+        AbstractTokenCounter.__init__(self, token_process)
+
+    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
+        """Update token usage counts in the token process."""
+        if self.token_process is None:
+            return
+
+        if prompt_tokens > 0:
+            self.token_process.sum_prompt_tokens(prompt_tokens)
+
+        if completion_tokens > 0:
+            self.token_process.sum_completion_tokens(completion_tokens)
+
+        self.token_process.sum_successful_requests(1)
+
+    @property
+    def ignore_llm(self) -> bool:
+        return False
+
+    @property
+    def ignore_chain(self) -> bool:
+        return True
+
+    @property
+    def ignore_agent(self) -> bool:
+        return False
+
+    @property
+    def ignore_chat_model(self) -> bool:
+        return False
+
+    @property
+    def ignore_retriever(self) -> bool:
+        return True
+
+    @property
+    def ignore_tools(self) -> bool:
+        return True
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Called when LLM starts processing."""
+        pass
+
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Called when LLM generates a new token."""
+        pass
+
+    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
+        """
+        Called when LLM ends processing.
+        Extracts token usage from LangChain response objects.
+        """
+        if self.token_process is None:
+            return
+
+        # Handle LangChain response format
+        if hasattr(response, "llm_output") and isinstance(response.llm_output, dict):
+            token_usage = response.llm_output.get("token_usage", {})
+
+            prompt_tokens = token_usage.get("prompt_tokens", 0)
+            completion_tokens = token_usage.get("completion_tokens", 0)
+
+            self.update_token_usage(prompt_tokens, completion_tokens)
+
+    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+        """Called when LLM errors."""
+        pass
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Called when a chain starts."""
+        pass
+
+    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+        """Called when a chain ends."""
+        pass
+
+    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
+        """Called when a chain errors."""
+        pass
+
+    def on_tool_start(
+        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+    ) -> None:
+        """Called when a tool starts."""
+        pass
+
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
+        """Called when a tool ends."""
+        pass
+
+    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
+        """Called when a tool errors."""
+        pass
+
+    def on_text(self, text: str, **kwargs: Any) -> None:
+        """Called when text is generated."""
+        pass
+
+    def on_agent_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None:
+        """Called when an agent starts."""
+        pass
+
+    def on_agent_end(self, output: Any, **kwargs: Any) -> None:
+        """Called when an agent ends."""
+        pass
+
+    def on_agent_error(self, error: BaseException, **kwargs: Any) -> None:
+        """Called when an agent errors."""
+        pass
+
+
+# For backward compatibility
+TokenCalcHandler = LiteLLMTokenCounter

tests/utilities/test_token_tracking.py (new file, 183 lines)
@@ -0,0 +1,183 @@
+#!/usr/bin/env python
+"""
+Test module for token tracking functionality in CrewAI.
+This tests both direct LangChain models and LiteLLM integration.
+"""
+
+import os
+from typing import Any, Dict
+from unittest.mock import MagicMock, patch
+
+import pytest
+from langchain_core.tools import Tool
+from langchain_openai import ChatOpenAI
+
+from crewai import Crew, Process, Task
+from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
+from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter
+from crewai.utilities.token_counter_callback import (
+    LangChainTokenCounter,
+    LiteLLMTokenCounter,
+)
+
+
+def get_weather(location: str = "San Francisco"):
+    """Simulates fetching current weather data for a given location."""
+    # In a real implementation, you could replace this with an API call.
+    return f"Current weather in {location}: Sunny, 25°C"
+
+
+class TestTokenTracking:
+    """Test suite for token tracking functionality."""
+
+    @pytest.fixture
+    def weather_tool(self):
+        """Create a simple weather tool for testing."""
+        return Tool(
+            name="Weather",
+            func=get_weather,
+            description="Useful for fetching current weather information for a given location.",
+        )
+
+    @pytest.fixture
+    def mock_openai_response(self):
+        """Create a mock OpenAI response with token usage information."""
+        return {
+            "usage": {
+                "prompt_tokens": 100,
+                "completion_tokens": 50,
+                "total_tokens": 150,
+            }
+        }
+
+    def test_token_process_basic(self):
+        """Test basic functionality of TokenProcess class."""
+        token_process = TokenProcess()
+
+        # Test adding prompt tokens
+        token_process.sum_prompt_tokens(100)
+        assert token_process.prompt_tokens == 100
+
+        # Test adding completion tokens
+        token_process.sum_completion_tokens(50)
+        assert token_process.completion_tokens == 50
+
+        # Test adding successful requests
+        token_process.sum_successful_requests(1)
+        assert token_process.successful_requests == 1
+
+        # Test getting summary
+        summary = token_process.get_summary()
+        assert summary.prompt_tokens == 100
+        assert summary.completion_tokens == 50
+        assert summary.total_tokens == 150
+        assert summary.successful_requests == 1
+
+    @patch("litellm.completion")
+    def test_litellm_token_counter(self, mock_completion):
+        """Test LiteLLMTokenCounter with a mock response."""
+        # Setup
+        token_process = TokenProcess()
+        counter = LiteLLMTokenCounter(token_process)
+
+        # Mock the response
+        mock_completion.return_value = {
+            "usage": {
+                "prompt_tokens": 100,
+                "completion_tokens": 50,
+            }
+        }
+
+        # Simulate a successful LLM call
+        counter.log_success_event(
+            kwargs={},
+            response_obj=mock_completion.return_value,
+            start_time=0,
+            end_time=1,
+        )
+
+        # Verify token counts were updated
+        assert token_process.prompt_tokens == 100
+        assert token_process.completion_tokens == 50
+        assert token_process.successful_requests == 1
+
+    def test_langchain_token_counter(self):
+        """Test LangChainTokenCounter with a mock response."""
+        # Setup
+        token_process = TokenProcess()
+        counter = LangChainTokenCounter(token_process)
+
+        # Create a mock LangChain response
+        mock_response = MagicMock()
+        mock_response.llm_output = {
+            "token_usage": {
+                "prompt_tokens": 100,
+                "completion_tokens": 50,
+            }
+        }
+
+        # Simulate a successful LLM call
+        counter.on_llm_end(mock_response)
+
+        # Verify token counts were updated
+        assert token_process.prompt_tokens == 100
+        assert token_process.completion_tokens == 50
+        assert token_process.successful_requests == 1
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY"),
+        reason="OPENAI_API_KEY environment variable not set",
+    )
+    def test_langchain_agent_adapter_token_tracking(self, weather_tool):
+        """
+        Integration test for token tracking with LangChainAgentAdapter.
+        This test requires an OpenAI API key.
+        """
+        # Initialize a ChatOpenAI model
+        llm = ChatOpenAI(model="gpt-3.5-turbo")
+
+        # Create a LangChainAgentAdapter with the direct LLM
+        agent = LangChainAgentAdapter(
+            langchain_agent=llm,
+            tools=[weather_tool],
+            role="Weather Agent",
+            goal="Provide current weather information for the requested location.",
+            backstory="An expert weather provider that fetches current weather information using simulated data.",
+            verbose=True,
+        )
+
+        # Create a weather task for the agent
+        task = Task(
+            description="Fetch the current weather for San Francisco.",
+            expected_output="A weather report showing current conditions in San Francisco.",
+            agent=agent,
+        )
+
+        # Create a crew with the single agent and task
+        crew = Crew(
+            agents=[agent],
+            tasks=[task],
+            verbose=True,
+            process=Process.sequential,
+        )
+
+        # Execute the crew
+        result = crew.kickoff()
+
+        # Verify token usage was tracked
+        assert result.token_usage is not None
+        assert result.token_usage.total_tokens > 0
+        assert result.token_usage.prompt_tokens > 0
+        assert result.token_usage.completion_tokens > 0
+        assert result.token_usage.successful_requests > 0
+
+        # Also verify token usage directly from the agent
+        usage = agent.token_process.get_summary()
+        assert usage.prompt_tokens > 0
+        assert usage.completion_tokens > 0
+        assert usage.total_tokens > 0
+        assert usage.successful_requests > 0
+
+
+if __name__ == "__main__":
+    pytest.main(["-xvs", __file__])
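
The module can also be run directly thanks to the `__main__` block above; the integration test is skipped unless OPENAI_API_KEY is set. An equivalent invocation from Python:

    import pytest
    pytest.main(["-xvs", "tests/utilities/test_token_tracking.py"])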
@@ -1 +0,0 @@
-