From 6efee89399e7e418c0f2507fc0970492705cf5b2 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Fri, 28 Feb 2025 12:30:30 -0500 Subject: [PATCH] refactor --- src/crewai/agents/langchain_agent_adapter.py | 147 ++++++++----------- 1 file changed, 60 insertions(+), 87 deletions(-) diff --git a/src/crewai/agents/langchain_agent_adapter.py b/src/crewai/agents/langchain_agent_adapter.py index 7526434cb..5f6061545 100644 --- a/src/crewai/agents/langchain_agent_adapter.py +++ b/src/crewai/agents/langchain_agent_adapter.py @@ -104,6 +104,62 @@ class LangChainAgentAdapter(BaseAgent): # Default fallback return str(message) + def _register_token_callback(self): + """ + Register the appropriate token counter callback with the language model. + This method handles different types of models (LiteLLM, LangChain, direct LLMs) + and different callback structures. + """ + # Skip if we already have a token callback registered + if self.token_callback is not None: + return + + # Skip if we don't have a token_process attribute + if not hasattr(self, "token_process"): + return + + # Determine if we're using LiteLLM or LangChain based on the agent type + if hasattr(self.langchain_agent, "client") and hasattr( + self.langchain_agent.client, "callbacks" + ): + # This is likely a LiteLLM-based agent + self.token_callback = LiteLLMTokenCounter(self.token_process) + + # Add our callback to the LLM directly + if isinstance(self.langchain_agent.client.callbacks, list): + if self.token_callback not in self.langchain_agent.client.callbacks: + self.langchain_agent.client.callbacks.append(self.token_callback) + else: + self.langchain_agent.client.callbacks = [self.token_callback] + else: + # This is likely a LangChain-based agent + self.token_callback = LangChainTokenCounter(self.token_process) + + # Add callback to the LangChain model + if hasattr(self.langchain_agent, "callbacks"): + if self.langchain_agent.callbacks is None: + self.langchain_agent.callbacks = [self.token_callback] + elif isinstance(self.langchain_agent.callbacks, list): + self.langchain_agent.callbacks.append(self.token_callback) + # For direct LLM models + elif hasattr(self.langchain_agent, "llm") and hasattr( + self.langchain_agent.llm, "callbacks" + ): + if self.langchain_agent.llm.callbacks is None: + self.langchain_agent.llm.callbacks = [self.token_callback] + elif isinstance(self.langchain_agent.llm.callbacks, list): + self.langchain_agent.llm.callbacks.append(self.token_callback) + # Direct LLM case + elif not hasattr(self.langchain_agent, "agent"): + # This might be a direct LLM, not an agent + if ( + not hasattr(self.langchain_agent, "callbacks") + or self.langchain_agent.callbacks is None + ): + self.langchain_agent.callbacks = [self.token_callback] + elif isinstance(self.langchain_agent.callbacks, list): + self.langchain_agent.callbacks.append(self.token_callback) + def execute_task( self, task: Task, @@ -181,48 +237,8 @@ class LangChainAgentAdapter(BaseAgent): else: task_prompt = self._use_trained_data(task_prompt=task_prompt) - # Initialize token tracking callback if needed - if hasattr(self, "token_process") and self.token_callback is None: - # Determine if we're using LangChain or LiteLLM based on the agent type - if hasattr(self.langchain_agent, "client") and hasattr( - self.langchain_agent.client, "callbacks" - ): - # This is likely a LiteLLM-based agent - self.token_callback = LiteLLMTokenCounter(self.token_process) - - # Add our callback to the LLM directly - if isinstance(self.langchain_agent.client.callbacks, list): - self.langchain_agent.client.callbacks.append(self.token_callback) - else: - self.langchain_agent.client.callbacks = [self.token_callback] - else: - # This is likely a LangChain-based agent - self.token_callback = LangChainTokenCounter(self.token_process) - - # Add callback to the LangChain model - if hasattr(self.langchain_agent, "callbacks"): - if self.langchain_agent.callbacks is None: - self.langchain_agent.callbacks = [self.token_callback] - elif isinstance(self.langchain_agent.callbacks, list): - self.langchain_agent.callbacks.append(self.token_callback) - # For direct LLM models - elif hasattr(self.langchain_agent, "llm") and hasattr( - self.langchain_agent.llm, "callbacks" - ): - if self.langchain_agent.llm.callbacks is None: - self.langchain_agent.llm.callbacks = [self.token_callback] - elif isinstance(self.langchain_agent.llm.callbacks, list): - self.langchain_agent.llm.callbacks.append(self.token_callback) - # Direct LLM case - elif not hasattr(self.langchain_agent, "agent"): - # This might be a direct LLM, not an agent - if ( - not hasattr(self.langchain_agent, "callbacks") - or self.langchain_agent.callbacks is None - ): - self.langchain_agent.callbacks = [self.token_callback] - elif isinstance(self.langchain_agent.callbacks, list): - self.langchain_agent.callbacks.append(self.token_callback) + # Register token tracking callback + self._register_token_callback() init_state = {"messages": [("user", task_prompt)]} @@ -408,51 +424,8 @@ class LangChainAgentAdapter(BaseAgent): agent_role = getattr(self, "role", "agent") sanitized_role = re.sub(r"\s+", "_", agent_role) - # Initialize token tracking callback if needed - if hasattr(self, "token_process") and self.token_callback is None: - # Determine if we're using LangChain or LiteLLM based on the agent type - if hasattr(self.langchain_agent, "client") and hasattr( - self.langchain_agent.client, "callbacks" - ): - # This is likely a LiteLLM-based agent - self.token_callback = LiteLLMTokenCounter(self.token_process) - - # Add our callback to the LLM directly - if isinstance(self.langchain_agent.client.callbacks, list): - if self.token_callback not in self.langchain_agent.client.callbacks: - self.langchain_agent.client.callbacks.append( - self.token_callback - ) - else: - self.langchain_agent.client.callbacks = [self.token_callback] - else: - # This is likely a LangChain-based agent - self.token_callback = LangChainTokenCounter(self.token_process) - - # Add callback to the LangChain model - if hasattr(self.langchain_agent, "callbacks"): - if self.langchain_agent.callbacks is None: - self.langchain_agent.callbacks = [self.token_callback] - elif isinstance(self.langchain_agent.callbacks, list): - self.langchain_agent.callbacks.append(self.token_callback) - # For direct LLM models - elif hasattr(self.langchain_agent, "llm") and hasattr( - self.langchain_agent.llm, "callbacks" - ): - if self.langchain_agent.llm.callbacks is None: - self.langchain_agent.llm.callbacks = [self.token_callback] - elif isinstance(self.langchain_agent.llm.callbacks, list): - self.langchain_agent.llm.callbacks.append(self.token_callback) - # Direct LLM case - elif not hasattr(self.langchain_agent, "agent"): - # This might be a direct LLM, not an agent - if ( - not hasattr(self.langchain_agent, "callbacks") - or self.langchain_agent.callbacks is None - ): - self.langchain_agent.callbacks = [self.token_callback] - elif isinstance(self.langchain_agent.callbacks, list): - self.langchain_agent.callbacks.append(self.token_callback) + # Register token tracking callback + self._register_token_callback() self.agent_executor = create_react_agent( model=self.langchain_agent,