mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
refactor
This commit is contained in:
@@ -104,6 +104,62 @@ class LangChainAgentAdapter(BaseAgent):
|
||||
# Default fallback
|
||||
return str(message)
|
||||
|
||||
def _register_token_callback(self):
|
||||
"""
|
||||
Register the appropriate token counter callback with the language model.
|
||||
This method handles different types of models (LiteLLM, LangChain, direct LLMs)
|
||||
and different callback structures.
|
||||
"""
|
||||
# Skip if we already have a token callback registered
|
||||
if self.token_callback is not None:
|
||||
return
|
||||
|
||||
# Skip if we don't have a token_process attribute
|
||||
if not hasattr(self, "token_process"):
|
||||
return
|
||||
|
||||
# Determine if we're using LiteLLM or LangChain based on the agent type
|
||||
if hasattr(self.langchain_agent, "client") and hasattr(
|
||||
self.langchain_agent.client, "callbacks"
|
||||
):
|
||||
# This is likely a LiteLLM-based agent
|
||||
self.token_callback = LiteLLMTokenCounter(self.token_process)
|
||||
|
||||
# Add our callback to the LLM directly
|
||||
if isinstance(self.langchain_agent.client.callbacks, list):
|
||||
if self.token_callback not in self.langchain_agent.client.callbacks:
|
||||
self.langchain_agent.client.callbacks.append(self.token_callback)
|
||||
else:
|
||||
self.langchain_agent.client.callbacks = [self.token_callback]
|
||||
else:
|
||||
# This is likely a LangChain-based agent
|
||||
self.token_callback = LangChainTokenCounter(self.token_process)
|
||||
|
||||
# Add callback to the LangChain model
|
||||
if hasattr(self.langchain_agent, "callbacks"):
|
||||
if self.langchain_agent.callbacks is None:
|
||||
self.langchain_agent.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.callbacks, list):
|
||||
self.langchain_agent.callbacks.append(self.token_callback)
|
||||
# For direct LLM models
|
||||
elif hasattr(self.langchain_agent, "llm") and hasattr(
|
||||
self.langchain_agent.llm, "callbacks"
|
||||
):
|
||||
if self.langchain_agent.llm.callbacks is None:
|
||||
self.langchain_agent.llm.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.llm.callbacks, list):
|
||||
self.langchain_agent.llm.callbacks.append(self.token_callback)
|
||||
# Direct LLM case
|
||||
elif not hasattr(self.langchain_agent, "agent"):
|
||||
# This might be a direct LLM, not an agent
|
||||
if (
|
||||
not hasattr(self.langchain_agent, "callbacks")
|
||||
or self.langchain_agent.callbacks is None
|
||||
):
|
||||
self.langchain_agent.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.callbacks, list):
|
||||
self.langchain_agent.callbacks.append(self.token_callback)
|
||||
|
||||
def execute_task(
|
||||
self,
|
||||
task: Task,
|
||||
@@ -181,48 +237,8 @@ class LangChainAgentAdapter(BaseAgent):
|
||||
else:
|
||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
# Initialize token tracking callback if needed
|
||||
if hasattr(self, "token_process") and self.token_callback is None:
|
||||
# Determine if we're using LangChain or LiteLLM based on the agent type
|
||||
if hasattr(self.langchain_agent, "client") and hasattr(
|
||||
self.langchain_agent.client, "callbacks"
|
||||
):
|
||||
# This is likely a LiteLLM-based agent
|
||||
self.token_callback = LiteLLMTokenCounter(self.token_process)
|
||||
|
||||
# Add our callback to the LLM directly
|
||||
if isinstance(self.langchain_agent.client.callbacks, list):
|
||||
self.langchain_agent.client.callbacks.append(self.token_callback)
|
||||
else:
|
||||
self.langchain_agent.client.callbacks = [self.token_callback]
|
||||
else:
|
||||
# This is likely a LangChain-based agent
|
||||
self.token_callback = LangChainTokenCounter(self.token_process)
|
||||
|
||||
# Add callback to the LangChain model
|
||||
if hasattr(self.langchain_agent, "callbacks"):
|
||||
if self.langchain_agent.callbacks is None:
|
||||
self.langchain_agent.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.callbacks, list):
|
||||
self.langchain_agent.callbacks.append(self.token_callback)
|
||||
# For direct LLM models
|
||||
elif hasattr(self.langchain_agent, "llm") and hasattr(
|
||||
self.langchain_agent.llm, "callbacks"
|
||||
):
|
||||
if self.langchain_agent.llm.callbacks is None:
|
||||
self.langchain_agent.llm.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.llm.callbacks, list):
|
||||
self.langchain_agent.llm.callbacks.append(self.token_callback)
|
||||
# Direct LLM case
|
||||
elif not hasattr(self.langchain_agent, "agent"):
|
||||
# This might be a direct LLM, not an agent
|
||||
if (
|
||||
not hasattr(self.langchain_agent, "callbacks")
|
||||
or self.langchain_agent.callbacks is None
|
||||
):
|
||||
self.langchain_agent.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.callbacks, list):
|
||||
self.langchain_agent.callbacks.append(self.token_callback)
|
||||
# Register token tracking callback
|
||||
self._register_token_callback()
|
||||
|
||||
init_state = {"messages": [("user", task_prompt)]}
|
||||
|
||||
@@ -408,51 +424,8 @@ class LangChainAgentAdapter(BaseAgent):
|
||||
agent_role = getattr(self, "role", "agent")
|
||||
sanitized_role = re.sub(r"\s+", "_", agent_role)
|
||||
|
||||
# Initialize token tracking callback if needed
|
||||
if hasattr(self, "token_process") and self.token_callback is None:
|
||||
# Determine if we're using LangChain or LiteLLM based on the agent type
|
||||
if hasattr(self.langchain_agent, "client") and hasattr(
|
||||
self.langchain_agent.client, "callbacks"
|
||||
):
|
||||
# This is likely a LiteLLM-based agent
|
||||
self.token_callback = LiteLLMTokenCounter(self.token_process)
|
||||
|
||||
# Add our callback to the LLM directly
|
||||
if isinstance(self.langchain_agent.client.callbacks, list):
|
||||
if self.token_callback not in self.langchain_agent.client.callbacks:
|
||||
self.langchain_agent.client.callbacks.append(
|
||||
self.token_callback
|
||||
)
|
||||
else:
|
||||
self.langchain_agent.client.callbacks = [self.token_callback]
|
||||
else:
|
||||
# This is likely a LangChain-based agent
|
||||
self.token_callback = LangChainTokenCounter(self.token_process)
|
||||
|
||||
# Add callback to the LangChain model
|
||||
if hasattr(self.langchain_agent, "callbacks"):
|
||||
if self.langchain_agent.callbacks is None:
|
||||
self.langchain_agent.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.callbacks, list):
|
||||
self.langchain_agent.callbacks.append(self.token_callback)
|
||||
# For direct LLM models
|
||||
elif hasattr(self.langchain_agent, "llm") and hasattr(
|
||||
self.langchain_agent.llm, "callbacks"
|
||||
):
|
||||
if self.langchain_agent.llm.callbacks is None:
|
||||
self.langchain_agent.llm.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.llm.callbacks, list):
|
||||
self.langchain_agent.llm.callbacks.append(self.token_callback)
|
||||
# Direct LLM case
|
||||
elif not hasattr(self.langchain_agent, "agent"):
|
||||
# This might be a direct LLM, not an agent
|
||||
if (
|
||||
not hasattr(self.langchain_agent, "callbacks")
|
||||
or self.langchain_agent.callbacks is None
|
||||
):
|
||||
self.langchain_agent.callbacks = [self.token_callback]
|
||||
elif isinstance(self.langchain_agent.callbacks, list):
|
||||
self.langchain_agent.callbacks.append(self.token_callback)
|
||||
# Register token tracking callback
|
||||
self._register_token_callback()
|
||||
|
||||
self.agent_executor = create_react_agent(
|
||||
model=self.langchain_agent,
|
||||
|
||||
Reference in New Issue
Block a user