mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
refactor
This commit is contained in:
@@ -104,6 +104,62 @@ class LangChainAgentAdapter(BaseAgent):
|
|||||||
# Default fallback
|
# Default fallback
|
||||||
return str(message)
|
return str(message)
|
||||||
|
|
||||||
|
def _register_token_callback(self):
|
||||||
|
"""
|
||||||
|
Register the appropriate token counter callback with the language model.
|
||||||
|
This method handles different types of models (LiteLLM, LangChain, direct LLMs)
|
||||||
|
and different callback structures.
|
||||||
|
"""
|
||||||
|
# Skip if we already have a token callback registered
|
||||||
|
if self.token_callback is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip if we don't have a token_process attribute
|
||||||
|
if not hasattr(self, "token_process"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine if we're using LiteLLM or LangChain based on the agent type
|
||||||
|
if hasattr(self.langchain_agent, "client") and hasattr(
|
||||||
|
self.langchain_agent.client, "callbacks"
|
||||||
|
):
|
||||||
|
# This is likely a LiteLLM-based agent
|
||||||
|
self.token_callback = LiteLLMTokenCounter(self.token_process)
|
||||||
|
|
||||||
|
# Add our callback to the LLM directly
|
||||||
|
if isinstance(self.langchain_agent.client.callbacks, list):
|
||||||
|
if self.token_callback not in self.langchain_agent.client.callbacks:
|
||||||
|
self.langchain_agent.client.callbacks.append(self.token_callback)
|
||||||
|
else:
|
||||||
|
self.langchain_agent.client.callbacks = [self.token_callback]
|
||||||
|
else:
|
||||||
|
# This is likely a LangChain-based agent
|
||||||
|
self.token_callback = LangChainTokenCounter(self.token_process)
|
||||||
|
|
||||||
|
# Add callback to the LangChain model
|
||||||
|
if hasattr(self.langchain_agent, "callbacks"):
|
||||||
|
if self.langchain_agent.callbacks is None:
|
||||||
|
self.langchain_agent.callbacks = [self.token_callback]
|
||||||
|
elif isinstance(self.langchain_agent.callbacks, list):
|
||||||
|
self.langchain_agent.callbacks.append(self.token_callback)
|
||||||
|
# For direct LLM models
|
||||||
|
elif hasattr(self.langchain_agent, "llm") and hasattr(
|
||||||
|
self.langchain_agent.llm, "callbacks"
|
||||||
|
):
|
||||||
|
if self.langchain_agent.llm.callbacks is None:
|
||||||
|
self.langchain_agent.llm.callbacks = [self.token_callback]
|
||||||
|
elif isinstance(self.langchain_agent.llm.callbacks, list):
|
||||||
|
self.langchain_agent.llm.callbacks.append(self.token_callback)
|
||||||
|
# Direct LLM case
|
||||||
|
elif not hasattr(self.langchain_agent, "agent"):
|
||||||
|
# This might be a direct LLM, not an agent
|
||||||
|
if (
|
||||||
|
not hasattr(self.langchain_agent, "callbacks")
|
||||||
|
or self.langchain_agent.callbacks is None
|
||||||
|
):
|
||||||
|
self.langchain_agent.callbacks = [self.token_callback]
|
||||||
|
elif isinstance(self.langchain_agent.callbacks, list):
|
||||||
|
self.langchain_agent.callbacks.append(self.token_callback)
|
||||||
|
|
||||||
def execute_task(
|
def execute_task(
|
||||||
self,
|
self,
|
||||||
task: Task,
|
task: Task,
|
||||||
@@ -181,48 +237,8 @@ class LangChainAgentAdapter(BaseAgent):
|
|||||||
else:
|
else:
|
||||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
||||||
|
|
||||||
# Initialize token tracking callback if needed
|
# Register token tracking callback
|
||||||
if hasattr(self, "token_process") and self.token_callback is None:
|
self._register_token_callback()
|
||||||
# Determine if we're using LangChain or LiteLLM based on the agent type
|
|
||||||
if hasattr(self.langchain_agent, "client") and hasattr(
|
|
||||||
self.langchain_agent.client, "callbacks"
|
|
||||||
):
|
|
||||||
# This is likely a LiteLLM-based agent
|
|
||||||
self.token_callback = LiteLLMTokenCounter(self.token_process)
|
|
||||||
|
|
||||||
# Add our callback to the LLM directly
|
|
||||||
if isinstance(self.langchain_agent.client.callbacks, list):
|
|
||||||
self.langchain_agent.client.callbacks.append(self.token_callback)
|
|
||||||
else:
|
|
||||||
self.langchain_agent.client.callbacks = [self.token_callback]
|
|
||||||
else:
|
|
||||||
# This is likely a LangChain-based agent
|
|
||||||
self.token_callback = LangChainTokenCounter(self.token_process)
|
|
||||||
|
|
||||||
# Add callback to the LangChain model
|
|
||||||
if hasattr(self.langchain_agent, "callbacks"):
|
|
||||||
if self.langchain_agent.callbacks is None:
|
|
||||||
self.langchain_agent.callbacks = [self.token_callback]
|
|
||||||
elif isinstance(self.langchain_agent.callbacks, list):
|
|
||||||
self.langchain_agent.callbacks.append(self.token_callback)
|
|
||||||
# For direct LLM models
|
|
||||||
elif hasattr(self.langchain_agent, "llm") and hasattr(
|
|
||||||
self.langchain_agent.llm, "callbacks"
|
|
||||||
):
|
|
||||||
if self.langchain_agent.llm.callbacks is None:
|
|
||||||
self.langchain_agent.llm.callbacks = [self.token_callback]
|
|
||||||
elif isinstance(self.langchain_agent.llm.callbacks, list):
|
|
||||||
self.langchain_agent.llm.callbacks.append(self.token_callback)
|
|
||||||
# Direct LLM case
|
|
||||||
elif not hasattr(self.langchain_agent, "agent"):
|
|
||||||
# This might be a direct LLM, not an agent
|
|
||||||
if (
|
|
||||||
not hasattr(self.langchain_agent, "callbacks")
|
|
||||||
or self.langchain_agent.callbacks is None
|
|
||||||
):
|
|
||||||
self.langchain_agent.callbacks = [self.token_callback]
|
|
||||||
elif isinstance(self.langchain_agent.callbacks, list):
|
|
||||||
self.langchain_agent.callbacks.append(self.token_callback)
|
|
||||||
|
|
||||||
init_state = {"messages": [("user", task_prompt)]}
|
init_state = {"messages": [("user", task_prompt)]}
|
||||||
|
|
||||||
@@ -408,51 +424,8 @@ class LangChainAgentAdapter(BaseAgent):
|
|||||||
agent_role = getattr(self, "role", "agent")
|
agent_role = getattr(self, "role", "agent")
|
||||||
sanitized_role = re.sub(r"\s+", "_", agent_role)
|
sanitized_role = re.sub(r"\s+", "_", agent_role)
|
||||||
|
|
||||||
# Initialize token tracking callback if needed
|
# Register token tracking callback
|
||||||
if hasattr(self, "token_process") and self.token_callback is None:
|
self._register_token_callback()
|
||||||
# Determine if we're using LangChain or LiteLLM based on the agent type
|
|
||||||
if hasattr(self.langchain_agent, "client") and hasattr(
|
|
||||||
self.langchain_agent.client, "callbacks"
|
|
||||||
):
|
|
||||||
# This is likely a LiteLLM-based agent
|
|
||||||
self.token_callback = LiteLLMTokenCounter(self.token_process)
|
|
||||||
|
|
||||||
# Add our callback to the LLM directly
|
|
||||||
if isinstance(self.langchain_agent.client.callbacks, list):
|
|
||||||
if self.token_callback not in self.langchain_agent.client.callbacks:
|
|
||||||
self.langchain_agent.client.callbacks.append(
|
|
||||||
self.token_callback
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.langchain_agent.client.callbacks = [self.token_callback]
|
|
||||||
else:
|
|
||||||
# This is likely a LangChain-based agent
|
|
||||||
self.token_callback = LangChainTokenCounter(self.token_process)
|
|
||||||
|
|
||||||
# Add callback to the LangChain model
|
|
||||||
if hasattr(self.langchain_agent, "callbacks"):
|
|
||||||
if self.langchain_agent.callbacks is None:
|
|
||||||
self.langchain_agent.callbacks = [self.token_callback]
|
|
||||||
elif isinstance(self.langchain_agent.callbacks, list):
|
|
||||||
self.langchain_agent.callbacks.append(self.token_callback)
|
|
||||||
# For direct LLM models
|
|
||||||
elif hasattr(self.langchain_agent, "llm") and hasattr(
|
|
||||||
self.langchain_agent.llm, "callbacks"
|
|
||||||
):
|
|
||||||
if self.langchain_agent.llm.callbacks is None:
|
|
||||||
self.langchain_agent.llm.callbacks = [self.token_callback]
|
|
||||||
elif isinstance(self.langchain_agent.llm.callbacks, list):
|
|
||||||
self.langchain_agent.llm.callbacks.append(self.token_callback)
|
|
||||||
# Direct LLM case
|
|
||||||
elif not hasattr(self.langchain_agent, "agent"):
|
|
||||||
# This might be a direct LLM, not an agent
|
|
||||||
if (
|
|
||||||
not hasattr(self.langchain_agent, "callbacks")
|
|
||||||
or self.langchain_agent.callbacks is None
|
|
||||||
):
|
|
||||||
self.langchain_agent.callbacks = [self.token_callback]
|
|
||||||
elif isinstance(self.langchain_agent.callbacks, list):
|
|
||||||
self.langchain_agent.callbacks.append(self.token_callback)
|
|
||||||
|
|
||||||
self.agent_executor = create_react_agent(
|
self.agent_executor = create_react_agent(
|
||||||
model=self.langchain_agent,
|
model=self.langchain_agent,
|
||||||
|
|||||||
Reference in New Issue
Block a user