Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-08 15:48:29 +00:00
Fix token tracking in LangChainAgentAdapter and refactor token_process attribute to be public
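With this change, token usage is read from the public token_process field on agents (and rolled up into the crew's usage metrics) instead of the private _token_process attribute. A minimal sketch of the intended usage, mirroring the new test added at the end of this commit (model name and prompts are illustrative, not part of the diff):

    from langchain_openai import ChatOpenAI

    from crewai import Crew, Task
    from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter

    agent = LangChainAgentAdapter(
        langchain_agent=ChatOpenAI(model="gpt-3.5-turbo"),
        role="Weather Agent",
        goal="Report the current weather.",
        backstory="A weather reporter working from simulated data.",
    )
    task = Task(
        description="Fetch the current weather for San Francisco.",
        expected_output="A short weather report.",
        agent=agent,
    )
    crew = Crew(agents=[agent], tasks=[task])

    result = crew.kickoff()
    print(result.token_usage)                 # aggregated UsageMetrics for the run
    print(agent.token_process.get_summary())  # per-agent totals via the public field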
@@ -82,13 +82,18 @@ class BaseAgent(ABC, BaseModel):
    """

    __hash__ = object.__hash__  # type: ignore

    model_config = {
        "arbitrary_types_allowed": True,
    }

    _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
    _rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
    _request_within_rpm_limit: Any = PrivateAttr(default=None)
    _original_role: Optional[str] = PrivateAttr(default=None)
    _original_goal: Optional[str] = PrivateAttr(default=None)
    _original_backstory: Optional[str] = PrivateAttr(default=None)
    _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
    token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
    id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
    formatting_errors: int = Field(
        default=0, description="Number of formatting errors."
@@ -198,8 +203,6 @@ class BaseAgent(ABC, BaseModel):
            self._rpm_controller = RPMController(
                max_rpm=self.max_rpm, logger=self._logger
            )
        if not self._token_process:
            self._token_process = TokenProcess()

        return self

@@ -219,8 +222,7 @@ class BaseAgent(ABC, BaseModel):
            self._rpm_controller = RPMController(
                max_rpm=self.max_rpm, logger=self._logger
            )
        if not self._token_process:
            self._token_process = TokenProcess()

        return self

    @property
@@ -268,7 +270,7 @@ class BaseAgent(ABC, BaseModel):
            "_logger",
            "_rpm_controller",
            "_request_within_rpm_limit",
            "_token_process",
            "token_process",
            "agent_executor",
            "tools",
            "tools_handler",
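Note (not part of the diff): because the new public field is declared with exclude=True and added to the copy exclusion list above, it should stay out of serialized output while remaining usable as a live counter. A minimal sketch, assuming the standard crewai Agent subclass with default settings:

    from crewai import Agent

    agent = Agent(role="Researcher", goal="Summarize findings", backstory="A careful analyst.")
    dumped = agent.model_dump()
    assert "token_process" not in dumped       # excluded from serialization
    agent.token_process.sum_prompt_tokens(10)  # still usable as an in-memory counter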
@@ -1,18 +1,17 @@
from typing import Any, List, Optional, Type, Union, cast

from crewai.tools.base_tool import Tool

try:
    from langchain_core.tools import Tool as LangChainTool  # type: ignore
except ImportError:
    LangChainTool = None

from pydantic import Field, field_validator

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.task import Task
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities.converter import Converter, generate_model_description
from crewai.utilities.token_counter_callback import (
    LangChainTokenCounter,
    LiteLLMTokenCounter,
)


class LangChainAgentAdapter(BaseAgent):

@@ -51,6 +50,8 @@ class LangChainAgentAdapter(BaseAgent):
    i18n: Any = None
    crew: Any = None
    knowledge: Any = None
    token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
    token_callback: Optional[Any] = None

    class Config:
        arbitrary_types_allowed = True
@@ -72,16 +73,35 @@ class LangChainAgentAdapter(BaseAgent):
    def _extract_text(self, message: Any) -> str:
        """
        Helper to extract plain text from a message object.
        This checks if the message is a dict with a "content" key, or has a "content" attribute.
        This checks if the message is a dict with a "content" key, or has a "content" attribute,
        or if it's a tuple from LangGraph's message format.
        """
        if isinstance(message, dict) and "content" in message:
            return message["content"]
        # Handle LangGraph message tuple format (role, content)
        if isinstance(message, tuple) and len(message) == 2:
            return str(message[1])

        # Handle dictionary with content key
        elif isinstance(message, dict):
            if "content" in message:
                return message["content"]
            # Handle LangGraph message format with additional metadata
            elif "messages" in message and message["messages"]:
                last_message = message["messages"][-1]
                if isinstance(last_message, tuple) and len(last_message) == 2:
                    return str(last_message[1])
                return self._extract_text(last_message)

        # Handle object with content attribute
        elif hasattr(message, "content") and isinstance(
            getattr(message, "content"), str
        ):
            return getattr(message, "content")

        # Handle string directly
        elif isinstance(message, str):
            return message

        # Default fallback
        return str(message)

    def execute_task(
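For reference, a sketch of the message shapes the updated helper is meant to normalize (illustrative values only; adapter stands for an already-constructed LangChainAgentAdapter instance):

    # Each of these is expected to come back as the plain string "Sunny, 25°C":
    adapter._extract_text(("assistant", "Sunny, 25°C"))                  # LangGraph (role, content) tuple
    adapter._extract_text({"content": "Sunny, 25°C"})                     # dict carrying a "content" key
    adapter._extract_text({"messages": [("assistant", "Sunny, 25°C")]})   # LangGraph state with a messages list
    adapter._extract_text("Sunny, 25°C")                                  # plain string passes through
    # Anything unrecognized falls back to str(message).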
@@ -161,19 +181,77 @@ class LangChainAgentAdapter(BaseAgent):
        else:
            task_prompt = self._use_trained_data(task_prompt=task_prompt)

        # Initialize token tracking callback if needed
        if hasattr(self, "token_process") and self.token_callback is None:
            # Determine if we're using LangChain or LiteLLM based on the agent type
            if hasattr(self.langchain_agent, "client") and hasattr(
                self.langchain_agent.client, "callbacks"
            ):
                # This is likely a LiteLLM-based agent
                self.token_callback = LiteLLMTokenCounter(self.token_process)

                # Add our callback to the LLM directly
                if isinstance(self.langchain_agent.client.callbacks, list):
                    self.langchain_agent.client.callbacks.append(self.token_callback)
                else:
                    self.langchain_agent.client.callbacks = [self.token_callback]
            else:
                # This is likely a LangChain-based agent
                self.token_callback = LangChainTokenCounter(self.token_process)

                # Add callback to the LangChain model
                if hasattr(self.langchain_agent, "callbacks"):
                    if self.langchain_agent.callbacks is None:
                        self.langchain_agent.callbacks = [self.token_callback]
                    elif isinstance(self.langchain_agent.callbacks, list):
                        self.langchain_agent.callbacks.append(self.token_callback)
                # For direct LLM models
                elif hasattr(self.langchain_agent, "llm") and hasattr(
                    self.langchain_agent.llm, "callbacks"
                ):
                    if self.langchain_agent.llm.callbacks is None:
                        self.langchain_agent.llm.callbacks = [self.token_callback]
                    elif isinstance(self.langchain_agent.llm.callbacks, list):
                        self.langchain_agent.llm.callbacks.append(self.token_callback)
                # Direct LLM case
                elif not hasattr(self.langchain_agent, "agent"):
                    # This might be a direct LLM, not an agent
                    if (
                        not hasattr(self.langchain_agent, "callbacks")
                        or self.langchain_agent.callbacks is None
                    ):
                        self.langchain_agent.callbacks = [self.token_callback]
                    elif isinstance(self.langchain_agent.callbacks, list):
                        self.langchain_agent.callbacks.append(self.token_callback)

        init_state = {"messages": [("user", task_prompt)]}

        # Estimate input tokens for tracking
        if hasattr(self, "token_process"):
            # Rough estimate based on characters (better than word count)
            estimated_prompt_tokens = len(task_prompt) // 4  # ~4 chars per token
            self.token_process.sum_prompt_tokens(estimated_prompt_tokens)

        state = self.agent_executor.invoke(init_state)

        # Extract output from state based on its structure
        if "structured_response" in state:
            current_output = state["structured_response"]
        elif "messages" in state and state["messages"]:
            last_message = state["messages"][-1]
            if isinstance(last_message, tuple):
                current_output = last_message[1]
            else:
                current_output = self._extract_text(last_message)
            current_output = self._extract_text(last_message)
        elif "output" in state:
            current_output = str(state["output"])
        else:
            current_output = ""
            # Fallback to extracting text from the entire state
            current_output = self._extract_text(state)

        # Estimate completion tokens for tracking if we don't have actual counts
        if hasattr(self, "token_process"):
            # Rough estimate based on characters
            estimated_completion_tokens = len(current_output) // 4  # ~4 chars per token
            self.token_process.sum_completion_tokens(estimated_completion_tokens)
            self.token_process.sum_successful_requests(1)

        if task.human_input:
            current_output = self._handle_human_feedback(current_output)
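Outside the adapter, the same wiring can be reproduced by hand; a rough sketch against a plain LangChain chat model (the model choice and prompt are assumptions, not part of the diff):

    from langchain_openai import ChatOpenAI

    from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
    from crewai.utilities.token_counter_callback import LangChainTokenCounter

    token_process = TokenProcess()
    counter = LangChainTokenCounter(token_process)

    # Registering the counter as a callback lets on_llm_end() read token_usage from llm_output.
    llm = ChatOpenAI(model="gpt-3.5-turbo", callbacks=[counter])
    llm.invoke("What is the weather in San Francisco?")

    print(token_process.get_summary())  # prompt/completion totals recorded by the callback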
@@ -203,20 +281,40 @@ class LangChainAgentAdapter(BaseAgent):
                f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\n"
                f"Updated answer:"
            )

            # Estimate input tokens for tracking
            if hasattr(self, "token_process"):
                # Rough estimate based on characters
                estimated_prompt_tokens = len(new_prompt) // 4  # ~4 chars per token
                self.token_process.sum_prompt_tokens(estimated_prompt_tokens)

            try:
                new_state = self.agent_executor.invoke(
                    {"messages": [("user", new_prompt)]}
                )
                # Extract output from state based on its structure
                if "structured_response" in new_state:
                    new_output = new_state["structured_response"]
                elif "messages" in new_state and new_state["messages"]:
                    last_message = new_state["messages"][-1]
                    if isinstance(last_message, tuple):
                        new_output = last_message[1]
                    else:
                        new_output = self._extract_text(last_message)
                    new_output = self._extract_text(last_message)
                elif "output" in new_state:
                    new_output = str(new_state["output"])
                else:
                    new_output = ""
                    # Fallback to extracting text from the entire state
                    new_output = self._extract_text(new_state)

                # Estimate completion tokens for tracking
                if hasattr(self, "token_process"):
                    # Rough estimate based on characters
                    estimated_completion_tokens = (
                        len(new_output) // 4
                    )  # ~4 chars per token
                    self.token_process.sum_completion_tokens(
                        estimated_completion_tokens
                    )
                    self.token_process.sum_successful_requests(1)

                current_output = new_output
            except Exception as e:
                print("Error during re-invocation with feedback:", e)
@@ -310,6 +408,52 @@ class LangChainAgentAdapter(BaseAgent):
        agent_role = getattr(self, "role", "agent")
        sanitized_role = re.sub(r"\s+", "_", agent_role)

        # Initialize token tracking callback if needed
        if hasattr(self, "token_process") and self.token_callback is None:
            # Determine if we're using LangChain or LiteLLM based on the agent type
            if hasattr(self.langchain_agent, "client") and hasattr(
                self.langchain_agent.client, "callbacks"
            ):
                # This is likely a LiteLLM-based agent
                self.token_callback = LiteLLMTokenCounter(self.token_process)

                # Add our callback to the LLM directly
                if isinstance(self.langchain_agent.client.callbacks, list):
                    if self.token_callback not in self.langchain_agent.client.callbacks:
                        self.langchain_agent.client.callbacks.append(
                            self.token_callback
                        )
                else:
                    self.langchain_agent.client.callbacks = [self.token_callback]
            else:
                # This is likely a LangChain-based agent
                self.token_callback = LangChainTokenCounter(self.token_process)

                # Add callback to the LangChain model
                if hasattr(self.langchain_agent, "callbacks"):
                    if self.langchain_agent.callbacks is None:
                        self.langchain_agent.callbacks = [self.token_callback]
                    elif isinstance(self.langchain_agent.callbacks, list):
                        self.langchain_agent.callbacks.append(self.token_callback)
                # For direct LLM models
                elif hasattr(self.langchain_agent, "llm") and hasattr(
                    self.langchain_agent.llm, "callbacks"
                ):
                    if self.langchain_agent.llm.callbacks is None:
                        self.langchain_agent.llm.callbacks = [self.token_callback]
                    elif isinstance(self.langchain_agent.llm.callbacks, list):
                        self.langchain_agent.llm.callbacks.append(self.token_callback)
                # Direct LLM case
                elif not hasattr(self.langchain_agent, "agent"):
                    # This might be a direct LLM, not an agent
                    if (
                        not hasattr(self.langchain_agent, "callbacks")
                        or self.langchain_agent.callbacks is None
                    ):
                        self.langchain_agent.callbacks = [self.token_callback]
                    elif isinstance(self.langchain_agent.callbacks, list):
                        self.langchain_agent.callbacks.append(self.token_callback)

        self.agent_executor = create_react_agent(
            model=self.langchain_agent,
            tools=used_tools,
@@ -641,7 +641,7 @@ class Crew(BaseModel):
        for after_callback in self.after_kickoff_callbacks:
            result = after_callback(result)

        metrics += [agent._token_process.get_summary() for agent in self.agents]
        metrics += [agent.token_process.get_summary() for agent in self.agents]

        self.usage_metrics = UsageMetrics()
        for metric in metrics:
@@ -1195,12 +1195,15 @@ class Crew(BaseModel):
        """Calculates and returns the usage metrics."""
        total_usage_metrics = UsageMetrics()
        for agent in self.agents:
            if hasattr(agent, "_token_process"):
                token_sum = agent._token_process.get_summary()
                total_usage_metrics.add_usage_metrics(token_sum)
            if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
                token_sum = self.manager_agent._token_process.get_summary()
            # Directly access token_process since it's now a field in BaseAgent
            token_sum = agent.token_process.get_summary()
            total_usage_metrics.add_usage_metrics(token_sum)

        if self.manager_agent:
            # Directly access token_process since it's now a field in BaseAgent
            token_sum = self.manager_agent.token_process.get_summary()
            total_usage_metrics.add_usage_metrics(token_sum)

        self.usage_metrics = total_usage_metrics
        return total_usage_metrics
@@ -1,15 +1,52 @@
import warnings
from typing import Any, Dict, Optional
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from langchain_core.callbacks.base import BaseCallbackHandler
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess


class TokenCalcHandler(CustomLogger):
    def __init__(self, token_cost_process: Optional[TokenProcess]):
        self.token_cost_process = token_cost_process
class AbstractTokenCounter(ABC):
    """
    Abstract base class for token counting callbacks.
    Implementations should track token usage from different LLM providers.
    """

    def __init__(self, token_process: Optional[TokenProcess] = None):
        """Initialize with a TokenProcess instance to track tokens."""
        self.token_process = token_process

    @abstractmethod
    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Update token usage counts in the token process."""
        pass


class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter):
    """
    Token counter implementation for LiteLLM.
    Uses LiteLLM's CustomLogger interface to track token usage.
    """

    def __init__(self, token_process: Optional[TokenProcess] = None):
        AbstractTokenCounter.__init__(self, token_process)
        CustomLogger.__init__(self)

    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Update token usage counts in the token process."""
        if self.token_process is None:
            return

        if prompt_tokens > 0:
            self.token_process.sum_prompt_tokens(prompt_tokens)

        if completion_tokens > 0:
            self.token_process.sum_completion_tokens(completion_tokens)

        self.token_process.sum_successful_requests(1)

    def log_success_event(
        self,
@@ -18,7 +55,11 @@ class TokenCalcHandler(CustomLogger):
        start_time: float,
        end_time: float,
    ) -> None:
        if self.token_cost_process is None:
        """
        Process successful LLM call and extract token usage information.
        This method is called by LiteLLM after a successful completion.
        """
        if self.token_process is None:
            return

        with warnings.catch_warnings():
@@ -26,18 +67,154 @@ class TokenCalcHandler(CustomLogger):
            if isinstance(response_obj, dict) and "usage" in response_obj:
                usage: Usage = response_obj["usage"]
                if usage:
                    self.token_cost_process.sum_successful_requests(1)
                    prompt_tokens = 0
                    completion_tokens = 0

                    if hasattr(usage, "prompt_tokens"):
                        self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
                        prompt_tokens = usage.prompt_tokens
                    elif isinstance(usage, dict) and "prompt_tokens" in usage:
                        prompt_tokens = usage["prompt_tokens"]

                    if hasattr(usage, "completion_tokens"):
                        self.token_cost_process.sum_completion_tokens(
                            usage.completion_tokens
                        )
                        completion_tokens = usage.completion_tokens
                    elif isinstance(usage, dict) and "completion_tokens" in usage:
                        completion_tokens = usage["completion_tokens"]

                    self.update_token_usage(prompt_tokens, completion_tokens)

                    # Handle cached tokens if available
                    if (
                        hasattr(usage, "prompt_tokens_details")
                        and usage.prompt_tokens_details
                        and usage.prompt_tokens_details.cached_tokens
                    ):
                        self.token_cost_process.sum_cached_prompt_tokens(
                        self.token_process.sum_cached_prompt_tokens(
                            usage.prompt_tokens_details.cached_tokens
                        )


class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
    """
    Token counter implementation for LangChain.
    Implements the necessary callback methods to track token usage from LangChain responses.
    """

    def __init__(self, token_process: Optional[TokenProcess] = None):
        BaseCallbackHandler.__init__(self)
        AbstractTokenCounter.__init__(self, token_process)

    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Update token usage counts in the token process."""
        if self.token_process is None:
            return

        if prompt_tokens > 0:
            self.token_process.sum_prompt_tokens(prompt_tokens)

        if completion_tokens > 0:
            self.token_process.sum_completion_tokens(completion_tokens)

        self.token_process.sum_successful_requests(1)

    @property
    def ignore_llm(self) -> bool:
        return False

    @property
    def ignore_chain(self) -> bool:
        return True

    @property
    def ignore_agent(self) -> bool:
        return False

    @property
    def ignore_chat_model(self) -> bool:
        return False

    @property
    def ignore_retriever(self) -> bool:
        return True

    @property
    def ignore_tools(self) -> bool:
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Called when LLM starts processing."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Called when LLM generates a new token."""
        pass

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """
        Called when LLM ends processing.
        Extracts token usage from LangChain response objects.
        """
        if self.token_process is None:
            return

        # Handle LangChain response format
        if hasattr(response, "llm_output") and isinstance(response.llm_output, dict):
            token_usage = response.llm_output.get("token_usage", {})

            prompt_tokens = token_usage.get("prompt_tokens", 0)
            completion_tokens = token_usage.get("completion_tokens", 0)

            self.update_token_usage(prompt_tokens, completion_tokens)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when LLM errors."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Called when a chain starts."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Called when a chain ends."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when a chain errors."""
        pass

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Called when a tool starts."""
        pass

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Called when a tool ends."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when a tool errors."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Called when text is generated."""
        pass

    def on_agent_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None:
        """Called when an agent starts."""
        pass

    def on_agent_end(self, output: Any, **kwargs: Any) -> None:
        """Called when an agent ends."""
        pass

    def on_agent_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when an agent errors."""
        pass


# For backward compatibility
TokenCalcHandler = LiteLLMTokenCounter
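A quick check of the compatibility alias above (a sketch, not part of the diff): existing call sites that import TokenCalcHandler keep working, now backed by LiteLLMTokenCounter.

    from crewai.utilities.token_counter_callback import LiteLLMTokenCounter, TokenCalcHandler

    assert TokenCalcHandler is LiteLLMTokenCounter
    handler = TokenCalcHandler(None)  # the optional TokenProcess argument is still accepted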
tests/utilities/test_token_tracking.py (new file, 183 lines)
@@ -0,0 +1,183 @@
#!/usr/bin/env python
"""
Test module for token tracking functionality in CrewAI.
This tests both direct LangChain models and LiteLLM integration.
"""

import os
from typing import Any, Dict
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.tools import Tool
from langchain_openai import ChatOpenAI

from crewai import Crew, Process, Task
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter
from crewai.utilities.token_counter_callback import (
    LangChainTokenCounter,
    LiteLLMTokenCounter,
)


def get_weather(location: str = "San Francisco"):
    """Simulates fetching current weather data for a given location."""
    # In a real implementation, you could replace this with an API call.
    return f"Current weather in {location}: Sunny, 25°C"


class TestTokenTracking:
    """Test suite for token tracking functionality."""

    @pytest.fixture
    def weather_tool(self):
        """Create a simple weather tool for testing."""
        return Tool(
            name="Weather",
            func=get_weather,
            description="Useful for fetching current weather information for a given location.",
        )

    @pytest.fixture
    def mock_openai_response(self):
        """Create a mock OpenAI response with token usage information."""
        return {
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150,
            }
        }

    def test_token_process_basic(self):
        """Test basic functionality of TokenProcess class."""
        token_process = TokenProcess()

        # Test adding prompt tokens
        token_process.sum_prompt_tokens(100)
        assert token_process.prompt_tokens == 100

        # Test adding completion tokens
        token_process.sum_completion_tokens(50)
        assert token_process.completion_tokens == 50

        # Test adding successful requests
        token_process.sum_successful_requests(1)
        assert token_process.successful_requests == 1

        # Test getting summary
        summary = token_process.get_summary()
        assert summary.prompt_tokens == 100
        assert summary.completion_tokens == 50
        assert summary.total_tokens == 150
        assert summary.successful_requests == 1

    @patch("litellm.completion")
    def test_litellm_token_counter(self, mock_completion):
        """Test LiteLLMTokenCounter with a mock response."""
        # Setup
        token_process = TokenProcess()
        counter = LiteLLMTokenCounter(token_process)

        # Mock the response
        mock_completion.return_value = {
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
            }
        }

        # Simulate a successful LLM call
        counter.log_success_event(
            kwargs={},
            response_obj=mock_completion.return_value,
            start_time=0,
            end_time=1,
        )

        # Verify token counts were updated
        assert token_process.prompt_tokens == 100
        assert token_process.completion_tokens == 50
        assert token_process.successful_requests == 1

    def test_langchain_token_counter(self):
        """Test LangChainTokenCounter with a mock response."""
        # Setup
        token_process = TokenProcess()
        counter = LangChainTokenCounter(token_process)

        # Create a mock LangChain response
        mock_response = MagicMock()
        mock_response.llm_output = {
            "token_usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
            }
        }

        # Simulate a successful LLM call
        counter.on_llm_end(mock_response)

        # Verify token counts were updated
        assert token_process.prompt_tokens == 100
        assert token_process.completion_tokens == 50
        assert token_process.successful_requests == 1

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY"),
        reason="OPENAI_API_KEY environment variable not set",
    )
    def test_langchain_agent_adapter_token_tracking(self, weather_tool):
        """
        Integration test for token tracking with LangChainAgentAdapter.
        This test requires an OpenAI API key.
        """
        # Initialize a ChatOpenAI model
        llm = ChatOpenAI(model="gpt-3.5-turbo")

        # Create a LangChainAgentAdapter with the direct LLM
        agent = LangChainAgentAdapter(
            langchain_agent=llm,
            tools=[weather_tool],
            role="Weather Agent",
            goal="Provide current weather information for the requested location.",
            backstory="An expert weather provider that fetches current weather information using simulated data.",
            verbose=True,
        )

        # Create a weather task for the agent
        task = Task(
            description="Fetch the current weather for San Francisco.",
            expected_output="A weather report showing current conditions in San Francisco.",
            agent=agent,
        )

        # Create a crew with the single agent and task
        crew = Crew(
            agents=[agent],
            tasks=[task],
            verbose=True,
            process=Process.sequential,
        )

        # Execute the crew
        result = crew.kickoff()

        # Verify token usage was tracked
        assert result.token_usage is not None
        assert result.token_usage.total_tokens > 0
        assert result.token_usage.prompt_tokens > 0
        assert result.token_usage.completion_tokens > 0
        assert result.token_usage.successful_requests > 0

        # Also verify token usage directly from the agent
        usage = agent.token_process.get_summary()
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0
        assert usage.total_tokens > 0
        assert usage.successful_requests > 0


if __name__ == "__main__":
    pytest.main(["-xvs", __file__])
@@ -1 +0,0 @@