This commit is contained in:
Brandon Hancock
2025-02-28 12:30:30 -05:00
parent 75d8e086a4
commit 6efee89399

View File

@@ -104,6 +104,62 @@ class LangChainAgentAdapter(BaseAgent):
# Default fallback
return str(message)
def _register_token_callback(self):
"""
Register the appropriate token counter callback with the language model.
This method handles different types of models (LiteLLM, LangChain, direct LLMs)
and different callback structures.
"""
# Skip if we already have a token callback registered
if self.token_callback is not None:
return
# Skip if we don't have a token_process attribute
if not hasattr(self, "token_process"):
return
# Determine if we're using LiteLLM or LangChain based on the agent type
if hasattr(self.langchain_agent, "client") and hasattr(
self.langchain_agent.client, "callbacks"
):
# This is likely a LiteLLM-based agent
self.token_callback = LiteLLMTokenCounter(self.token_process)
# Add our callback to the LLM directly
if isinstance(self.langchain_agent.client.callbacks, list):
if self.token_callback not in self.langchain_agent.client.callbacks:
self.langchain_agent.client.callbacks.append(self.token_callback)
else:
self.langchain_agent.client.callbacks = [self.token_callback]
else:
# This is likely a LangChain-based agent
self.token_callback = LangChainTokenCounter(self.token_process)
# Add callback to the LangChain model
if hasattr(self.langchain_agent, "callbacks"):
if self.langchain_agent.callbacks is None:
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
# For direct LLM models
elif hasattr(self.langchain_agent, "llm") and hasattr(
self.langchain_agent.llm, "callbacks"
):
if self.langchain_agent.llm.callbacks is None:
self.langchain_agent.llm.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.llm.callbacks, list):
self.langchain_agent.llm.callbacks.append(self.token_callback)
# Direct LLM case
elif not hasattr(self.langchain_agent, "agent"):
# This might be a direct LLM, not an agent
if (
not hasattr(self.langchain_agent, "callbacks")
or self.langchain_agent.callbacks is None
):
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
def execute_task(
self,
task: Task,
@@ -181,48 +237,8 @@ class LangChainAgentAdapter(BaseAgent):
else:
task_prompt = self._use_trained_data(task_prompt=task_prompt)
# Initialize token tracking callback if needed
if hasattr(self, "token_process") and self.token_callback is None:
# Determine if we're using LangChain or LiteLLM based on the agent type
if hasattr(self.langchain_agent, "client") and hasattr(
self.langchain_agent.client, "callbacks"
):
# This is likely a LiteLLM-based agent
self.token_callback = LiteLLMTokenCounter(self.token_process)
# Add our callback to the LLM directly
if isinstance(self.langchain_agent.client.callbacks, list):
self.langchain_agent.client.callbacks.append(self.token_callback)
else:
self.langchain_agent.client.callbacks = [self.token_callback]
else:
# This is likely a LangChain-based agent
self.token_callback = LangChainTokenCounter(self.token_process)
# Add callback to the LangChain model
if hasattr(self.langchain_agent, "callbacks"):
if self.langchain_agent.callbacks is None:
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
# For direct LLM models
elif hasattr(self.langchain_agent, "llm") and hasattr(
self.langchain_agent.llm, "callbacks"
):
if self.langchain_agent.llm.callbacks is None:
self.langchain_agent.llm.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.llm.callbacks, list):
self.langchain_agent.llm.callbacks.append(self.token_callback)
# Direct LLM case
elif not hasattr(self.langchain_agent, "agent"):
# This might be a direct LLM, not an agent
if (
not hasattr(self.langchain_agent, "callbacks")
or self.langchain_agent.callbacks is None
):
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
# Register token tracking callback
self._register_token_callback()
init_state = {"messages": [("user", task_prompt)]}
@@ -408,51 +424,8 @@ class LangChainAgentAdapter(BaseAgent):
agent_role = getattr(self, "role", "agent")
sanitized_role = re.sub(r"\s+", "_", agent_role)
# Initialize token tracking callback if needed
if hasattr(self, "token_process") and self.token_callback is None:
# Determine if we're using LangChain or LiteLLM based on the agent type
if hasattr(self.langchain_agent, "client") and hasattr(
self.langchain_agent.client, "callbacks"
):
# This is likely a LiteLLM-based agent
self.token_callback = LiteLLMTokenCounter(self.token_process)
# Add our callback to the LLM directly
if isinstance(self.langchain_agent.client.callbacks, list):
if self.token_callback not in self.langchain_agent.client.callbacks:
self.langchain_agent.client.callbacks.append(
self.token_callback
)
else:
self.langchain_agent.client.callbacks = [self.token_callback]
else:
# This is likely a LangChain-based agent
self.token_callback = LangChainTokenCounter(self.token_process)
# Add callback to the LangChain model
if hasattr(self.langchain_agent, "callbacks"):
if self.langchain_agent.callbacks is None:
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
# For direct LLM models
elif hasattr(self.langchain_agent, "llm") and hasattr(
self.langchain_agent.llm, "callbacks"
):
if self.langchain_agent.llm.callbacks is None:
self.langchain_agent.llm.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.llm.callbacks, list):
self.langchain_agent.llm.callbacks.append(self.token_callback)
# Direct LLM case
elif not hasattr(self.langchain_agent, "agent"):
# This might be a direct LLM, not an agent
if (
not hasattr(self.langchain_agent, "callbacks")
or self.langchain_agent.callbacks is None
):
self.langchain_agent.callbacks = [self.token_callback]
elif isinstance(self.langchain_agent.callbacks, list):
self.langchain_agent.callbacks.append(self.token_callback)
# Register token tracking callback
self._register_token_callback()
self.agent_executor = create_react_agent(
model=self.langchain_agent,