From 88d8079dcda27f3a13688b071a8c5f3219f7c99a Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Fri, 28 Feb 2025 11:31:08 -0500 Subject: [PATCH] Fix token tracking in LangChainAgentAdapter and refactor token_process attribute to be public --- src/crewai/agents/agent_builder/base_agent.py | 14 +- src/crewai/agents/langchain_agent_adapter.py | 184 ++++++++++++++-- src/crewai/crew.py | 15 +- .../utilities/token_counter_callback.py | 199 +++++++++++++++++- tests/utilities/test_token_tracking.py | 183 ++++++++++++++++ .../crewai/agents/langchain_agent_adapter.py | 1 - 6 files changed, 552 insertions(+), 44 deletions(-) create mode 100644 tests/utilities/test_token_tracking.py delete mode 100644 title=src/crewai/agents/langchain_agent_adapter.py diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index 215c66a66..b2fa66f0e 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -82,13 +82,18 @@ class BaseAgent(ABC, BaseModel): """ __hash__ = object.__hash__ # type: ignore + + model_config = { + "arbitrary_types_allowed": True, + } + _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False)) _rpm_controller: Optional[RPMController] = PrivateAttr(default=None) _request_within_rpm_limit: Any = PrivateAttr(default=None) _original_role: Optional[str] = PrivateAttr(default=None) _original_goal: Optional[str] = PrivateAttr(default=None) _original_backstory: Optional[str] = PrivateAttr(default=None) - _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess) + token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True) id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True) formatting_errors: int = Field( default=0, description="Number of formatting errors." @@ -198,8 +203,6 @@ class BaseAgent(ABC, BaseModel): self._rpm_controller = RPMController( max_rpm=self.max_rpm, logger=self._logger ) - if not self._token_process: - self._token_process = TokenProcess() return self @@ -219,8 +222,7 @@ class BaseAgent(ABC, BaseModel): self._rpm_controller = RPMController( max_rpm=self.max_rpm, logger=self._logger ) - if not self._token_process: - self._token_process = TokenProcess() + return self @property @@ -268,7 +270,7 @@ class BaseAgent(ABC, BaseModel): "_logger", "_rpm_controller", "_request_within_rpm_limit", - "_token_process", + "token_process", "agent_executor", "tools", "tools_handler", diff --git a/src/crewai/agents/langchain_agent_adapter.py b/src/crewai/agents/langchain_agent_adapter.py index 941dad38d..7526434cb 100644 --- a/src/crewai/agents/langchain_agent_adapter.py +++ b/src/crewai/agents/langchain_agent_adapter.py @@ -1,18 +1,17 @@ from typing import Any, List, Optional, Type, Union, cast -from crewai.tools.base_tool import Tool - -try: - from langchain_core.tools import Tool as LangChainTool # type: ignore -except ImportError: - LangChainTool = None - from pydantic import Field, field_validator from crewai.agents.agent_builder.base_agent import BaseAgent +from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.task import Task from crewai.tools import BaseTool +from crewai.tools.base_tool import Tool from crewai.utilities.converter import Converter, generate_model_description +from crewai.utilities.token_counter_callback import ( + LangChainTokenCounter, + LiteLLMTokenCounter, +) class LangChainAgentAdapter(BaseAgent): @@ -51,6 +50,8 @@ class LangChainAgentAdapter(BaseAgent): i18n: Any = None crew: Any = None knowledge: Any = None + token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True) + token_callback: Optional[Any] = None class Config: arbitrary_types_allowed = True @@ -72,16 +73,35 @@ class LangChainAgentAdapter(BaseAgent): def _extract_text(self, message: Any) -> str: """ Helper to extract plain text from a message object. - This checks if the message is a dict with a "content" key, or has a "content" attribute. + This checks if the message is a dict with a "content" key, or has a "content" attribute, + or if it's a tuple from LangGraph's message format. """ - if isinstance(message, dict) and "content" in message: - return message["content"] + # Handle LangGraph message tuple format (role, content) + if isinstance(message, tuple) and len(message) == 2: + return str(message[1]) + + # Handle dictionary with content key + elif isinstance(message, dict): + if "content" in message: + return message["content"] + # Handle LangGraph message format with additional metadata + elif "messages" in message and message["messages"]: + last_message = message["messages"][-1] + if isinstance(last_message, tuple) and len(last_message) == 2: + return str(last_message[1]) + return self._extract_text(last_message) + + # Handle object with content attribute elif hasattr(message, "content") and isinstance( getattr(message, "content"), str ): return getattr(message, "content") + + # Handle string directly elif isinstance(message, str): return message + + # Default fallback return str(message) def execute_task( @@ -161,19 +181,77 @@ class LangChainAgentAdapter(BaseAgent): else: task_prompt = self._use_trained_data(task_prompt=task_prompt) + # Initialize token tracking callback if needed + if hasattr(self, "token_process") and self.token_callback is None: + # Determine if we're using LangChain or LiteLLM based on the agent type + if hasattr(self.langchain_agent, "client") and hasattr( + self.langchain_agent.client, "callbacks" + ): + # This is likely a LiteLLM-based agent + self.token_callback = LiteLLMTokenCounter(self.token_process) + + # Add our callback to the LLM directly + if isinstance(self.langchain_agent.client.callbacks, list): + self.langchain_agent.client.callbacks.append(self.token_callback) + else: + self.langchain_agent.client.callbacks = [self.token_callback] + else: + # This is likely a LangChain-based agent + self.token_callback = LangChainTokenCounter(self.token_process) + + # Add callback to the LangChain model + if hasattr(self.langchain_agent, "callbacks"): + if self.langchain_agent.callbacks is None: + self.langchain_agent.callbacks = [self.token_callback] + elif isinstance(self.langchain_agent.callbacks, list): + self.langchain_agent.callbacks.append(self.token_callback) + # For direct LLM models + elif hasattr(self.langchain_agent, "llm") and hasattr( + self.langchain_agent.llm, "callbacks" + ): + if self.langchain_agent.llm.callbacks is None: + self.langchain_agent.llm.callbacks = [self.token_callback] + elif isinstance(self.langchain_agent.llm.callbacks, list): + self.langchain_agent.llm.callbacks.append(self.token_callback) + # Direct LLM case + elif not hasattr(self.langchain_agent, "agent"): + # This might be a direct LLM, not an agent + if ( + not hasattr(self.langchain_agent, "callbacks") + or self.langchain_agent.callbacks is None + ): + self.langchain_agent.callbacks = [self.token_callback] + elif isinstance(self.langchain_agent.callbacks, list): + self.langchain_agent.callbacks.append(self.token_callback) + init_state = {"messages": [("user", task_prompt)]} + + # Estimate input tokens for tracking + if hasattr(self, "token_process"): + # Rough estimate based on characters (better than word count) + estimated_prompt_tokens = len(task_prompt) // 4 # ~4 chars per token + self.token_process.sum_prompt_tokens(estimated_prompt_tokens) + state = self.agent_executor.invoke(init_state) + # Extract output from state based on its structure if "structured_response" in state: current_output = state["structured_response"] elif "messages" in state and state["messages"]: last_message = state["messages"][-1] - if isinstance(last_message, tuple): - current_output = last_message[1] - else: - current_output = self._extract_text(last_message) + current_output = self._extract_text(last_message) + elif "output" in state: + current_output = str(state["output"]) else: - current_output = "" + # Fallback to extracting text from the entire state + current_output = self._extract_text(state) + + # Estimate completion tokens for tracking if we don't have actual counts + if hasattr(self, "token_process"): + # Rough estimate based on characters + estimated_completion_tokens = len(current_output) // 4 # ~4 chars per token + self.token_process.sum_completion_tokens(estimated_completion_tokens) + self.token_process.sum_successful_requests(1) if task.human_input: current_output = self._handle_human_feedback(current_output) @@ -203,20 +281,40 @@ class LangChainAgentAdapter(BaseAgent): f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\n" f"Updated answer:" ) + + # Estimate input tokens for tracking + if hasattr(self, "token_process"): + # Rough estimate based on characters + estimated_prompt_tokens = len(new_prompt) // 4 # ~4 chars per token + self.token_process.sum_prompt_tokens(estimated_prompt_tokens) + try: new_state = self.agent_executor.invoke( {"messages": [("user", new_prompt)]} ) + # Extract output from state based on its structure if "structured_response" in new_state: new_output = new_state["structured_response"] elif "messages" in new_state and new_state["messages"]: last_message = new_state["messages"][-1] - if isinstance(last_message, tuple): - new_output = last_message[1] - else: - new_output = self._extract_text(last_message) + new_output = self._extract_text(last_message) + elif "output" in new_state: + new_output = str(new_state["output"]) else: - new_output = "" + # Fallback to extracting text from the entire state + new_output = self._extract_text(new_state) + + # Estimate completion tokens for tracking + if hasattr(self, "token_process"): + # Rough estimate based on characters + estimated_completion_tokens = ( + len(new_output) // 4 + ) # ~4 chars per token + self.token_process.sum_completion_tokens( + estimated_completion_tokens + ) + self.token_process.sum_successful_requests(1) + current_output = new_output except Exception as e: print("Error during re-invocation with feedback:", e) @@ -310,6 +408,52 @@ class LangChainAgentAdapter(BaseAgent): agent_role = getattr(self, "role", "agent") sanitized_role = re.sub(r"\s+", "_", agent_role) + # Initialize token tracking callback if needed + if hasattr(self, "token_process") and self.token_callback is None: + # Determine if we're using LangChain or LiteLLM based on the agent type + if hasattr(self.langchain_agent, "client") and hasattr( + self.langchain_agent.client, "callbacks" + ): + # This is likely a LiteLLM-based agent + self.token_callback = LiteLLMTokenCounter(self.token_process) + + # Add our callback to the LLM directly + if isinstance(self.langchain_agent.client.callbacks, list): + if self.token_callback not in self.langchain_agent.client.callbacks: + self.langchain_agent.client.callbacks.append( + self.token_callback + ) + else: + self.langchain_agent.client.callbacks = [self.token_callback] + else: + # This is likely a LangChain-based agent + self.token_callback = LangChainTokenCounter(self.token_process) + + # Add callback to the LangChain model + if hasattr(self.langchain_agent, "callbacks"): + if self.langchain_agent.callbacks is None: + self.langchain_agent.callbacks = [self.token_callback] + elif isinstance(self.langchain_agent.callbacks, list): + self.langchain_agent.callbacks.append(self.token_callback) + # For direct LLM models + elif hasattr(self.langchain_agent, "llm") and hasattr( + self.langchain_agent.llm, "callbacks" + ): + if self.langchain_agent.llm.callbacks is None: + self.langchain_agent.llm.callbacks = [self.token_callback] + elif isinstance(self.langchain_agent.llm.callbacks, list): + self.langchain_agent.llm.callbacks.append(self.token_callback) + # Direct LLM case + elif not hasattr(self.langchain_agent, "agent"): + # This might be a direct LLM, not an agent + if ( + not hasattr(self.langchain_agent, "callbacks") + or self.langchain_agent.callbacks is None + ): + self.langchain_agent.callbacks = [self.token_callback] + elif isinstance(self.langchain_agent.callbacks, list): + self.langchain_agent.callbacks.append(self.token_callback) + self.agent_executor = create_react_agent( model=self.langchain_agent, tools=used_tools, diff --git a/src/crewai/crew.py b/src/crewai/crew.py index ecf8c83de..c3f1c1613 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -641,7 +641,7 @@ class Crew(BaseModel): for after_callback in self.after_kickoff_callbacks: result = after_callback(result) - metrics += [agent._token_process.get_summary() for agent in self.agents] + metrics += [agent.token_process.get_summary() for agent in self.agents] self.usage_metrics = UsageMetrics() for metric in metrics: @@ -1195,12 +1195,15 @@ class Crew(BaseModel): """Calculates and returns the usage metrics.""" total_usage_metrics = UsageMetrics() for agent in self.agents: - if hasattr(agent, "_token_process"): - token_sum = agent._token_process.get_summary() - total_usage_metrics.add_usage_metrics(token_sum) - if self.manager_agent and hasattr(self.manager_agent, "_token_process"): - token_sum = self.manager_agent._token_process.get_summary() + # Directly access token_process since it's now a field in BaseAgent + token_sum = agent.token_process.get_summary() total_usage_metrics.add_usage_metrics(token_sum) + + if self.manager_agent: + # Directly access token_process since it's now a field in BaseAgent + token_sum = self.manager_agent.token_process.get_summary() + total_usage_metrics.add_usage_metrics(token_sum) + self.usage_metrics = total_usage_metrics return total_usage_metrics diff --git a/src/crewai/utilities/token_counter_callback.py b/src/crewai/utilities/token_counter_callback.py index 7037ad5c4..9b9e07bdb 100644 --- a/src/crewai/utilities/token_counter_callback.py +++ b/src/crewai/utilities/token_counter_callback.py @@ -1,15 +1,52 @@ import warnings -from typing import Any, Dict, Optional +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Union +from langchain_core.callbacks.base import BaseCallbackHandler from litellm.integrations.custom_logger import CustomLogger from litellm.types.utils import Usage from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess -class TokenCalcHandler(CustomLogger): - def __init__(self, token_cost_process: Optional[TokenProcess]): - self.token_cost_process = token_cost_process +class AbstractTokenCounter(ABC): + """ + Abstract base class for token counting callbacks. + Implementations should track token usage from different LLM providers. + """ + + def __init__(self, token_process: Optional[TokenProcess] = None): + """Initialize with a TokenProcess instance to track tokens.""" + self.token_process = token_process + + @abstractmethod + def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None: + """Update token usage counts in the token process.""" + pass + + +class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter): + """ + Token counter implementation for LiteLLM. + Uses LiteLLM's CustomLogger interface to track token usage. + """ + + def __init__(self, token_process: Optional[TokenProcess] = None): + AbstractTokenCounter.__init__(self, token_process) + CustomLogger.__init__(self) + + def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None: + """Update token usage counts in the token process.""" + if self.token_process is None: + return + + if prompt_tokens > 0: + self.token_process.sum_prompt_tokens(prompt_tokens) + + if completion_tokens > 0: + self.token_process.sum_completion_tokens(completion_tokens) + + self.token_process.sum_successful_requests(1) def log_success_event( self, @@ -18,7 +55,11 @@ class TokenCalcHandler(CustomLogger): start_time: float, end_time: float, ) -> None: - if self.token_cost_process is None: + """ + Process successful LLM call and extract token usage information. + This method is called by LiteLLM after a successful completion. + """ + if self.token_process is None: return with warnings.catch_warnings(): @@ -26,18 +67,154 @@ class TokenCalcHandler(CustomLogger): if isinstance(response_obj, dict) and "usage" in response_obj: usage: Usage = response_obj["usage"] if usage: - self.token_cost_process.sum_successful_requests(1) + prompt_tokens = 0 + completion_tokens = 0 + if hasattr(usage, "prompt_tokens"): - self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens) + prompt_tokens = usage.prompt_tokens + elif isinstance(usage, dict) and "prompt_tokens" in usage: + prompt_tokens = usage["prompt_tokens"] + if hasattr(usage, "completion_tokens"): - self.token_cost_process.sum_completion_tokens( - usage.completion_tokens - ) + completion_tokens = usage.completion_tokens + elif isinstance(usage, dict) and "completion_tokens" in usage: + completion_tokens = usage["completion_tokens"] + + self.update_token_usage(prompt_tokens, completion_tokens) + + # Handle cached tokens if available if ( hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens ): - self.token_cost_process.sum_cached_prompt_tokens( + self.token_process.sum_cached_prompt_tokens( usage.prompt_tokens_details.cached_tokens ) + + +class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter): + """ + Token counter implementation for LangChain. + Implements the necessary callback methods to track token usage from LangChain responses. + """ + + def __init__(self, token_process: Optional[TokenProcess] = None): + BaseCallbackHandler.__init__(self) + AbstractTokenCounter.__init__(self, token_process) + + def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None: + """Update token usage counts in the token process.""" + if self.token_process is None: + return + + if prompt_tokens > 0: + self.token_process.sum_prompt_tokens(prompt_tokens) + + if completion_tokens > 0: + self.token_process.sum_completion_tokens(completion_tokens) + + self.token_process.sum_successful_requests(1) + + @property + def ignore_llm(self) -> bool: + return False + + @property + def ignore_chain(self) -> bool: + return True + + @property + def ignore_agent(self) -> bool: + return False + + @property + def ignore_chat_model(self) -> bool: + return False + + @property + def ignore_retriever(self) -> bool: + return True + + @property + def ignore_tools(self) -> bool: + return True + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Called when LLM starts processing.""" + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Called when LLM generates a new token.""" + pass + + def on_llm_end(self, response: Any, **kwargs: Any) -> None: + """ + Called when LLM ends processing. + Extracts token usage from LangChain response objects. + """ + if self.token_process is None: + return + + # Handle LangChain response format + if hasattr(response, "llm_output") and isinstance(response.llm_output, dict): + token_usage = response.llm_output.get("token_usage", {}) + + prompt_tokens = token_usage.get("prompt_tokens", 0) + completion_tokens = token_usage.get("completion_tokens", 0) + + self.update_token_usage(prompt_tokens, completion_tokens) + + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: + """Called when LLM errors.""" + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Called when a chain starts.""" + pass + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Called when a chain ends.""" + pass + + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: + """Called when a chain errors.""" + pass + + def on_tool_start( + self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + ) -> None: + """Called when a tool starts.""" + pass + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Called when a tool ends.""" + pass + + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: + """Called when a tool errors.""" + pass + + def on_text(self, text: str, **kwargs: Any) -> None: + """Called when text is generated.""" + pass + + def on_agent_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None: + """Called when an agent starts.""" + pass + + def on_agent_end(self, output: Any, **kwargs: Any) -> None: + """Called when an agent ends.""" + pass + + def on_agent_error(self, error: BaseException, **kwargs: Any) -> None: + """Called when an agent errors.""" + pass + + +# For backward compatibility +TokenCalcHandler = LiteLLMTokenCounter diff --git a/tests/utilities/test_token_tracking.py b/tests/utilities/test_token_tracking.py new file mode 100644 index 000000000..53a4455ad --- /dev/null +++ b/tests/utilities/test_token_tracking.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python +""" +Test module for token tracking functionality in CrewAI. +This tests both direct LangChain models and LiteLLM integration. +""" + +import os +from typing import Any, Dict +from unittest.mock import MagicMock, patch + +import pytest +from langchain_core.tools import Tool +from langchain_openai import ChatOpenAI + +from crewai import Crew, Process, Task +from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess +from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter +from crewai.utilities.token_counter_callback import ( + LangChainTokenCounter, + LiteLLMTokenCounter, +) + + +def get_weather(location: str = "San Francisco"): + """Simulates fetching current weather data for a given location.""" + # In a real implementation, you could replace this with an API call. + return f"Current weather in {location}: Sunny, 25°C" + + +class TestTokenTracking: + """Test suite for token tracking functionality.""" + + @pytest.fixture + def weather_tool(self): + """Create a simple weather tool for testing.""" + return Tool( + name="Weather", + func=get_weather, + description="Useful for fetching current weather information for a given location.", + ) + + @pytest.fixture + def mock_openai_response(self): + """Create a mock OpenAI response with token usage information.""" + return { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + } + } + + def test_token_process_basic(self): + """Test basic functionality of TokenProcess class.""" + token_process = TokenProcess() + + # Test adding prompt tokens + token_process.sum_prompt_tokens(100) + assert token_process.prompt_tokens == 100 + + # Test adding completion tokens + token_process.sum_completion_tokens(50) + assert token_process.completion_tokens == 50 + + # Test adding successful requests + token_process.sum_successful_requests(1) + assert token_process.successful_requests == 1 + + # Test getting summary + summary = token_process.get_summary() + assert summary.prompt_tokens == 100 + assert summary.completion_tokens == 50 + assert summary.total_tokens == 150 + assert summary.successful_requests == 1 + + @patch("litellm.completion") + def test_litellm_token_counter(self, mock_completion): + """Test LiteLLMTokenCounter with a mock response.""" + # Setup + token_process = TokenProcess() + counter = LiteLLMTokenCounter(token_process) + + # Mock the response + mock_completion.return_value = { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + } + } + + # Simulate a successful LLM call + counter.log_success_event( + kwargs={}, + response_obj=mock_completion.return_value, + start_time=0, + end_time=1, + ) + + # Verify token counts were updated + assert token_process.prompt_tokens == 100 + assert token_process.completion_tokens == 50 + assert token_process.successful_requests == 1 + + def test_langchain_token_counter(self): + """Test LangChainTokenCounter with a mock response.""" + # Setup + token_process = TokenProcess() + counter = LangChainTokenCounter(token_process) + + # Create a mock LangChain response + mock_response = MagicMock() + mock_response.llm_output = { + "token_usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + } + } + + # Simulate a successful LLM call + counter.on_llm_end(mock_response) + + # Verify token counts were updated + assert token_process.prompt_tokens == 100 + assert token_process.completion_tokens == 50 + assert token_process.successful_requests == 1 + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY"), + reason="OPENAI_API_KEY environment variable not set", + ) + def test_langchain_agent_adapter_token_tracking(self, weather_tool): + """ + Integration test for token tracking with LangChainAgentAdapter. + This test requires an OpenAI API key. + """ + # Initialize a ChatOpenAI model + llm = ChatOpenAI(model="gpt-3.5-turbo") + + # Create a LangChainAgentAdapter with the direct LLM + agent = LangChainAgentAdapter( + langchain_agent=llm, + tools=[weather_tool], + role="Weather Agent", + goal="Provide current weather information for the requested location.", + backstory="An expert weather provider that fetches current weather information using simulated data.", + verbose=True, + ) + + # Create a weather task for the agent + task = Task( + description="Fetch the current weather for San Francisco.", + expected_output="A weather report showing current conditions in San Francisco.", + agent=agent, + ) + + # Create a crew with the single agent and task + crew = Crew( + agents=[agent], + tasks=[task], + verbose=True, + process=Process.sequential, + ) + + # Execute the crew + result = crew.kickoff() + + # Verify token usage was tracked + assert result.token_usage is not None + assert result.token_usage.total_tokens > 0 + assert result.token_usage.prompt_tokens > 0 + assert result.token_usage.completion_tokens > 0 + assert result.token_usage.successful_requests > 0 + + # Also verify token usage directly from the agent + usage = agent.token_process.get_summary() + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens > 0 + assert usage.successful_requests > 0 + + +if __name__ == "__main__": + pytest.main(["-xvs", __file__]) diff --git a/title=src/crewai/agents/langchain_agent_adapter.py b/title=src/crewai/agents/langchain_agent_adapter.py deleted file mode 100644 index 0519ecba6..000000000 --- a/title=src/crewai/agents/langchain_agent_adapter.py +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file