Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-20 21:38:14 +00:00

Compare commits (52 commits): devin/1768... → brandon/br...
Commits in this range (SHA1):

6efee89399, 75d8e086a4, ef48cbe971, 88d8079dcd, 33ef612cd5, 6c81acac00, 2181166a62,
3f17789152, ed0b8e1563, 4915420ed2, 271faa917b, e87eb57edc, e6e4bb15d7, 92a3349d64,
940f0647a9, 5db512bef6, 48d2b8c320, f6b0b492a4, 4b63b29787, 860efc3b42, 70ab4ad003,
6fb25a1af7, ca9277ae4c, e58e544304, 84e0a9e686, 6e0e9b30fe, f6393fd088, 64804682fc,
cc7669ab39, 377b64ac81, 470254c3e2, 497190f823, 241adb8ed0, e60c6e66a4, afd01e3c0c,
81b8ae0abd, ae19437473, f8c74b4fbb, 134e7ab241, 74e63621a5, 8953af6133, 2b438baad4,
185556b7e3, 1e23d37a14, df21f01441, b957fc1a18, fd0e1bdd1a, 265b37316b, ff32880a54,
a38483e1b4, 7910dc9337, 796e50aba8
@@ -1,7 +1,7 @@
 import re
 import shutil
 import subprocess
-from typing import Any, Dict, List, Literal, Optional, Sequence, Union
+from typing import Any, Dict, List, Literal, Optional, Sequence, Union, cast
 
 from pydantic import Field, InstanceOf, PrivateAttr, model_validator
 
@@ -170,27 +170,19 @@ class Agent(BaseAgent):
             Output of the agent
         """
         if self.tools_handler:
             self.tools_handler.last_used_tool = {}  # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
 
         task_prompt = task.prompt()
 
         # If the task requires output in JSON or Pydantic format,
         # append specific instructions to the task prompt to ensure
         # that the final answer does not include any code block markers
         if task.output_json or task.output_pydantic:
-            # Generate the schema based on the output format
-            if task.output_json:
-                # schema = json.dumps(task.output_json, indent=2)
-                schema = generate_model_description(task.output_json)
-                task_prompt += "\n" + self.i18n.slice(
-                    "formatted_task_instructions"
-                ).format(output_format=schema)
-
-            elif task.output_pydantic:
-                schema = generate_model_description(task.output_pydantic)
-                task_prompt += "\n" + self.i18n.slice(
-                    "formatted_task_instructions"
-                ).format(output_format=schema)
+            # Choose the output format, preferring output_json if available
+            output_format = (
+                task.output_json if task.output_json else task.output_pydantic
+            )
+            schema = generate_model_description(cast(type, output_format))
+            task_prompt += f"\n{self.i18n.slice('formatted_task_instructions').format(output_format=schema)}"
 
         if context:
             task_prompt = self.i18n.slice("task_with_context").format(
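Review note: the unified branch above collapses the two near-duplicate `output_json` / `output_pydantic` paths into a single selection before calling `generate_model_description`. A minimal sketch of that selection logic, using a hypothetical `Report` model as a stand-in for `task.output_pydantic`:

```python
# Sketch only: `Report` is a hypothetical output model, not part of crewAI.
from pydantic import BaseModel

from crewai.utilities.converter import generate_model_description


class Report(BaseModel):
    title: str
    bullet_points: list[str]


output_json = None        # stand-in for task.output_json
output_pydantic = Report  # stand-in for task.output_pydantic

# output_json wins when both are set; otherwise fall back to output_pydantic.
output_format = output_json if output_json else output_pydantic
schema = generate_model_description(output_format)
print(schema)  # textual description of the expected structured output
```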
@@ -276,9 +268,6 @@ class Agent(BaseAgent):
             raise e
         result = self.execute_task(task, context, tools)
 
-        if self.max_rpm and self._rpm_controller:
-            self._rpm_controller.stop_rpm_counter()
-
         # If there was any tool in self.tools_results that had result_as_answer
         # set to True, return the results of the last tool that had
         # result_as_answer set to True
@@ -338,7 +327,7 @@ class Agent(BaseAgent):
             request_within_rpm_limit=(
                 self._rpm_controller.check_or_wait if self._rpm_controller else None
             ),
-            callbacks=[TokenCalcHandler(self._token_process)],
+            callbacks=[TokenCalcHandler(self.token_process)],
         )
 
     def get_delegation_tools(self, agents: List[BaseAgent]):
@@ -73,20 +73,27 @@ class BaseAgent(ABC, BaseModel):
         Increment formatting errors.
     copy() -> "BaseAgent":
         Create a copy of the agent.
-    set_rpm_controller(rpm_controller: RPMController) -> None:
+    set_rpm_controller(rpm_controller: Optional[RPMController] = None) -> None:
         Set the rpm controller for the agent.
     set_private_attrs() -> "BaseAgent":
         Set private attributes.
+    configure_executor(cache_handler: CacheHandler, rpm_controller: RPMController) -> None:
+        Configure the agent's executor with both cache and RPM handling.
     """
 
     __hash__ = object.__hash__  # type: ignore
 
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
+
     _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
     _rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
     _request_within_rpm_limit: Any = PrivateAttr(default=None)
     _original_role: Optional[str] = PrivateAttr(default=None)
     _original_goal: Optional[str] = PrivateAttr(default=None)
     _original_backstory: Optional[str] = PrivateAttr(default=None)
-    _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
+    token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
     id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
     formatting_errors: int = Field(
         default=0, description="Number of formatting errors."
@@ -196,8 +203,6 @@ class BaseAgent(ABC, BaseModel):
             self._rpm_controller = RPMController(
                 max_rpm=self.max_rpm, logger=self._logger
             )
-        if not self._token_process:
-            self._token_process = TokenProcess()
 
         return self
 
@@ -217,8 +222,7 @@ class BaseAgent(ABC, BaseModel):
             self._rpm_controller = RPMController(
                 max_rpm=self.max_rpm, logger=self._logger
             )
-        if not self._token_process:
-            self._token_process = TokenProcess()
+
         return self
 
     @property
@@ -266,7 +270,7 @@ class BaseAgent(ABC, BaseModel):
             "_logger",
             "_rpm_controller",
             "_request_within_rpm_limit",
-            "_token_process",
+            "token_process",
             "agent_executor",
             "tools",
             "tools_handler",
@@ -337,20 +341,49 @@ class BaseAgent(ABC, BaseModel):
         if self.cache:
             self.cache_handler = cache_handler
             self.tools_handler.cache = cache_handler
-            self.create_agent_executor()
+            # Only create the executor if it hasn't been created yet.
+            if self.agent_executor is None:
+                self.create_agent_executor()
 
     def increment_formatting_errors(self) -> None:
         self.formatting_errors += 1
 
-    def set_rpm_controller(self, rpm_controller: RPMController) -> None:
-        """Set the rpm controller for the agent.
-
-        Args:
-            rpm_controller: An instance of the RPMController class.
-        """
-        if not self._rpm_controller:
-            self._rpm_controller = rpm_controller
-            self.create_agent_executor()
+    def set_rpm_controller(
+        self, rpm_controller: Optional[RPMController] = None
+    ) -> None:
+        """
+        Set the RPM controller for the agent. If no rpm_controller is provided, then:
+        - use self.max_rpm if set, or
+        - if self.crew exists and has max_rpm, use that.
+        """
+        if self._rpm_controller is None:
+            if rpm_controller is not None:
+                self._rpm_controller = rpm_controller
+            elif self.max_rpm:
+                self._rpm_controller = RPMController(
+                    max_rpm=self.max_rpm, logger=self._logger
+                )
+            elif self.crew and getattr(self.crew, "max_rpm", None):
+                self._rpm_controller = RPMController(
+                    max_rpm=self.crew.max_rpm, logger=self._logger
+                )
+            # else: no rpm limit provided – leave the controller None
+            if self.agent_executor is None:
+                self.create_agent_executor()
+
+    def configure_executor(
+        self, cache_handler: CacheHandler, rpm_controller: Optional[RPMController]
+    ) -> None:
+        """Configure the agent's executor with both cache and RPM handling.
+
+        This method delegates to set_cache_handler and set_rpm_controller, applying the configuration
+        only if the respective flags or values are set.
+        """
+        if self.cache:
+            self.set_cache_handler(cache_handler)
+        # Use the injected RPM controller rather than auto-creating one
+        if rpm_controller:
+            self.set_rpm_controller(rpm_controller)
 
     def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
         pass
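Review note: `configure_executor` gives the crew one injection point for both handlers. A rough usage sketch of the new method; the import paths and constructor arguments shown here are assumptions based on how the diff itself uses these classes, not a verified snippet from the PR:

```python
# Sketch: wiring an existing agent the way Crew.initialize_dependencies does.
from crewai.agents.cache import CacheHandler
from crewai.utilities import Logger, RPMController

cache_handler = CacheHandler()
rpm_controller = RPMController(max_rpm=10, logger=Logger(verbose=True))

# `agent` is assumed to be a previously constructed crewai Agent instance.
agent.configure_executor(cache_handler, rpm_controller)
# The cache handler is applied only if agent.cache is True; the RPM controller
# is applied only when one is actually provided, so unlimited agents keep None.
```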
@@ -88,7 +88,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             tool.name: tool for tool in self.tools
         }
         self.stop = stop_words
-        self.llm.stop = list(set(self.llm.stop + self.stop))
+        self.llm.stop = list(set((self.llm.stop or []) + self.stop))
 
     def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
         if "system" in self.prompt:
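Review note: the `(self.llm.stop or [])` guard matters because a freshly constructed LLM may have `stop=None`, and concatenating a list onto `None` raises `TypeError`. A tiny generic illustration of the pattern:

```python
existing_stop = None                  # e.g. an LLM configured without stop words
new_stop = ["\nObservation:"]

# merged = list(set(existing_stop + new_stop))        # TypeError: NoneType + list
merged = list(set((existing_stop or []) + new_stop))  # safe: treats None as []
print(merged)
```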
New file: src/crewai/agents/langchain_agent_adapter.py (468 lines)
@@ -0,0 +1,468 @@
|
|||||||
|
from typing import Any, List, Optional, Type, Union, cast
|
||||||
|
|
||||||
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
|
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||||
|
from crewai.task import Task
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
from crewai.tools.base_tool import Tool
|
||||||
|
from crewai.utilities.converter import Converter, generate_model_description
|
||||||
|
from crewai.utilities.token_counter_callback import (
|
||||||
|
LangChainTokenCounter,
|
||||||
|
LiteLLMTokenCounter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LangChainAgentAdapter(BaseAgent):
|
||||||
|
"""
|
||||||
|
Adapter class to wrap a LangChain agent and make it compatible with CrewAI's BaseAgent interface.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- This adapter does not require LangChain as a dependency.
|
||||||
|
- It wraps an external LangChain agent (passed as any type) and delegates calls
|
||||||
|
such as execute_task() to the LangChain agent's invoke() method.
|
||||||
|
- Extended logic is added to build prompts, incorporate memory, knowledge, training hints,
|
||||||
|
and now a human feedback loop similar to what is done in CrewAgentExecutor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
langchain_agent: Any = Field(
|
||||||
|
...,
|
||||||
|
description="The wrapped LangChain runnable agent instance. It is expected to have an 'invoke' method.",
|
||||||
|
)
|
||||||
|
tools: Optional[List[Union[BaseTool, Any]]] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Tools at the agent's disposal. Accepts both CrewAI BaseTool instances and other tools.",
|
||||||
|
)
|
||||||
|
function_calling_llm: Optional[Any] = Field(
|
||||||
|
default=None, description="Optional function calling LLM."
|
||||||
|
)
|
||||||
|
step_callback: Optional[Any] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Callback executed after each step of agent execution.",
|
||||||
|
)
|
||||||
|
allow_code_execution: Optional[bool] = Field(
|
||||||
|
default=False, description="Enable code execution for the agent."
|
||||||
|
)
|
||||||
|
multimodal: bool = Field(
|
||||||
|
default=False, description="Whether the agent is multimodal."
|
||||||
|
)
|
||||||
|
i18n: Any = None
|
||||||
|
crew: Any = None
|
||||||
|
knowledge: Any = None
|
||||||
|
token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
|
||||||
|
token_callback: Optional[Any] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@field_validator("tools", mode="before")
|
||||||
|
def convert_tools(cls, value):
|
||||||
|
"""Ensure tools are valid CrewAI BaseTool instances."""
|
||||||
|
if not value:
|
||||||
|
return value
|
||||||
|
new_tools = []
|
||||||
|
for tool in value:
|
||||||
|
# If tool is already a CrewAI BaseTool instance, keep it as is.
|
||||||
|
if isinstance(tool, BaseTool):
|
||||||
|
new_tools.append(tool)
|
||||||
|
else:
|
||||||
|
new_tools.append(Tool.from_langchain(tool))
|
||||||
|
return new_tools
|
||||||
|
|
||||||
|
def _extract_text(self, message: Any) -> str:
|
||||||
|
"""
|
||||||
|
Helper to extract plain text from a message object.
|
||||||
|
This checks if the message is a dict with a "content" key, or has a "content" attribute,
|
||||||
|
or if it's a tuple from LangGraph's message format.
|
||||||
|
"""
|
||||||
|
# Handle LangGraph message tuple format (role, content)
|
||||||
|
if isinstance(message, tuple) and len(message) == 2:
|
||||||
|
return str(message[1])
|
||||||
|
|
||||||
|
# Handle dictionary with content key
|
||||||
|
elif isinstance(message, dict):
|
||||||
|
if "content" in message:
|
||||||
|
return message["content"]
|
||||||
|
# Handle LangGraph message format with additional metadata
|
||||||
|
elif "messages" in message and message["messages"]:
|
||||||
|
last_message = message["messages"][-1]
|
||||||
|
if isinstance(last_message, tuple) and len(last_message) == 2:
|
||||||
|
return str(last_message[1])
|
||||||
|
return self._extract_text(last_message)
|
||||||
|
|
||||||
|
# Handle object with content attribute
|
||||||
|
elif hasattr(message, "content") and isinstance(
|
||||||
|
getattr(message, "content"), str
|
||||||
|
):
|
||||||
|
return getattr(message, "content")
|
||||||
|
|
||||||
|
# Handle string directly
|
||||||
|
elif isinstance(message, str):
|
||||||
|
return message
|
||||||
|
|
||||||
|
# Default fallback
|
||||||
|
return str(message)
|
||||||
|
|
||||||
|
def _register_token_callback(self):
|
||||||
|
"""
|
||||||
|
Register the appropriate token counter callback with the language model.
|
||||||
|
This method handles different types of models (LiteLLM, LangChain, direct LLMs)
|
||||||
|
and different callback structures.
|
||||||
|
"""
|
||||||
|
# Skip if we already have a token callback registered
|
||||||
|
if self.token_callback is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip if we don't have a token_process attribute
|
||||||
|
if not hasattr(self, "token_process"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine if we're using LiteLLM or LangChain based on the agent type
|
||||||
|
if hasattr(self.langchain_agent, "client") and hasattr(
|
||||||
|
self.langchain_agent.client, "callbacks"
|
||||||
|
):
|
||||||
|
# This is likely a LiteLLM-based agent
|
||||||
|
self.token_callback = LiteLLMTokenCounter(self.token_process)
|
||||||
|
|
||||||
|
# Add our callback to the LLM directly
|
||||||
|
if isinstance(self.langchain_agent.client.callbacks, list):
|
||||||
|
if self.token_callback not in self.langchain_agent.client.callbacks:
|
||||||
|
self.langchain_agent.client.callbacks.append(self.token_callback)
|
||||||
|
else:
|
||||||
|
self.langchain_agent.client.callbacks = [self.token_callback]
|
||||||
|
else:
|
||||||
|
# This is likely a LangChain-based agent
|
||||||
|
self.token_callback = LangChainTokenCounter(self.token_process)
|
||||||
|
|
||||||
|
# Add callback to the LangChain model
|
||||||
|
if hasattr(self.langchain_agent, "callbacks"):
|
||||||
|
if self.langchain_agent.callbacks is None:
|
||||||
|
self.langchain_agent.callbacks = [self.token_callback]
|
||||||
|
elif isinstance(self.langchain_agent.callbacks, list):
|
||||||
|
self.langchain_agent.callbacks.append(self.token_callback)
|
||||||
|
# For direct LLM models
|
||||||
|
elif hasattr(self.langchain_agent, "llm") and hasattr(
|
||||||
|
self.langchain_agent.llm, "callbacks"
|
||||||
|
):
|
||||||
|
if self.langchain_agent.llm.callbacks is None:
|
||||||
|
self.langchain_agent.llm.callbacks = [self.token_callback]
|
||||||
|
elif isinstance(self.langchain_agent.llm.callbacks, list):
|
||||||
|
self.langchain_agent.llm.callbacks.append(self.token_callback)
|
||||||
|
# Direct LLM case
|
||||||
|
elif not hasattr(self.langchain_agent, "agent"):
|
||||||
|
# This might be a direct LLM, not an agent
|
||||||
|
if (
|
||||||
|
not hasattr(self.langchain_agent, "callbacks")
|
||||||
|
or self.langchain_agent.callbacks is None
|
||||||
|
):
|
||||||
|
self.langchain_agent.callbacks = [self.token_callback]
|
||||||
|
elif isinstance(self.langchain_agent.callbacks, list):
|
||||||
|
self.langchain_agent.callbacks.append(self.token_callback)
|
||||||
|
|
||||||
|
def execute_task(
|
||||||
|
self,
|
||||||
|
task: Task,
|
||||||
|
context: Optional[str] = None,
|
||||||
|
tools: Optional[List[BaseTool]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Execute a task by building the full task prompt (with memory, knowledge, tool instructions,
|
||||||
|
and training hints) then delegating execution to the wrapped LangChain agent.
|
||||||
|
If the task requires human input, a feedback loop is run that mimics the CrewAgentExecutor.
|
||||||
|
"""
|
||||||
|
task_prompt = task.prompt()
|
||||||
|
|
||||||
|
if task.output_json or task.output_pydantic:
|
||||||
|
# Choose the output format, preferring output_json if available
|
||||||
|
output_format = (
|
||||||
|
task.output_json if task.output_json else task.output_pydantic
|
||||||
|
)
|
||||||
|
schema = generate_model_description(cast(type, output_format))
|
||||||
|
instruction = self.i18n.slice("formatted_task_instructions").format(
|
||||||
|
output_format=schema
|
||||||
|
)
|
||||||
|
task_prompt += f"\n{instruction}"
|
||||||
|
|
||||||
|
if context:
|
||||||
|
task_prompt = self.i18n.slice("task_with_context").format(
|
||||||
|
task=task_prompt, context=context
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.crew and self.crew.memory:
|
||||||
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
self.crew.memory_config,
|
||||||
|
self.crew._short_term_memory,
|
||||||
|
self.crew._long_term_memory,
|
||||||
|
self.crew._entity_memory,
|
||||||
|
self.crew._user_memory,
|
||||||
|
)
|
||||||
|
memory = contextual_memory.build_context_for_task(task, context)
|
||||||
|
if memory.strip():
|
||||||
|
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||||
|
|
||||||
|
if self.knowledge:
|
||||||
|
agent_knowledge_snippets = self.knowledge.query([task.prompt()])
|
||||||
|
if agent_knowledge_snippets:
|
||||||
|
from crewai.knowledge.utils.knowledge_utils import (
|
||||||
|
extract_knowledge_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_knowledge_context = extract_knowledge_context(
|
||||||
|
agent_knowledge_snippets
|
||||||
|
)
|
||||||
|
if agent_knowledge_context:
|
||||||
|
task_prompt += agent_knowledge_context
|
||||||
|
|
||||||
|
if self.crew:
|
||||||
|
knowledge_snippets = self.crew.query_knowledge([task.prompt()])
|
||||||
|
if knowledge_snippets:
|
||||||
|
from crewai.knowledge.utils.knowledge_utils import (
|
||||||
|
extract_knowledge_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
crew_knowledge_context = extract_knowledge_context(knowledge_snippets)
|
||||||
|
if crew_knowledge_context:
|
||||||
|
task_prompt += crew_knowledge_context
|
||||||
|
|
||||||
|
tools = tools or self.tools or []
|
||||||
|
self.create_agent_executor(tools=tools)
|
||||||
|
|
||||||
|
self._show_start_logs(task)
|
||||||
|
|
||||||
|
if self.crew and getattr(self.crew, "_train", False):
|
||||||
|
task_prompt = self._training_handler(task_prompt=task_prompt)
|
||||||
|
else:
|
||||||
|
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
||||||
|
|
||||||
|
# Register token tracking callback
|
||||||
|
self._register_token_callback()
|
||||||
|
|
||||||
|
init_state = {"messages": [("user", task_prompt)]}
|
||||||
|
|
||||||
|
# Estimate input tokens for tracking
|
||||||
|
if hasattr(self, "token_process"):
|
||||||
|
# Rough estimate based on characters (better than word count)
|
||||||
|
estimated_prompt_tokens = len(task_prompt) // 4 # ~4 chars per token
|
||||||
|
self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
|
||||||
|
|
||||||
|
state = self.agent_executor.invoke(init_state)
|
||||||
|
|
||||||
|
# Extract output from state based on its structure
|
||||||
|
if "structured_response" in state:
|
||||||
|
current_output = state["structured_response"]
|
||||||
|
elif "messages" in state and state["messages"]:
|
||||||
|
last_message = state["messages"][-1]
|
||||||
|
current_output = self._extract_text(last_message)
|
||||||
|
elif "output" in state:
|
||||||
|
current_output = str(state["output"])
|
||||||
|
else:
|
||||||
|
# Fallback to extracting text from the entire state
|
||||||
|
current_output = self._extract_text(state)
|
||||||
|
|
||||||
|
# Estimate completion tokens for tracking if we don't have actual counts
|
||||||
|
if hasattr(self, "token_process"):
|
||||||
|
# Rough estimate based on characters
|
||||||
|
estimated_completion_tokens = len(current_output) // 4 # ~4 chars per token
|
||||||
|
self.token_process.sum_completion_tokens(estimated_completion_tokens)
|
||||||
|
self.token_process.sum_successful_requests(1)
|
||||||
|
|
||||||
|
if task.human_input:
|
||||||
|
current_output = self._handle_human_feedback(current_output)
|
||||||
|
|
||||||
|
return current_output
|
||||||
|
|
||||||
|
def _handle_human_feedback(self, current_output: str) -> str:
|
||||||
|
"""
|
||||||
|
Implements a feedback loop that prompts the user for feedback and then instructs
|
||||||
|
the underlying LangChain agent to regenerate its answer with the requested changes.
|
||||||
|
Only the inner content of the output is displayed to the user.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
print("\nAgent output:")
|
||||||
|
# Print only the inner text extracted from current_output.
|
||||||
|
print(self._extract_text(current_output))
|
||||||
|
|
||||||
|
feedback = input("\nEnter your feedback (or press Enter to accept): ")
|
||||||
|
if not feedback.strip():
|
||||||
|
break # No feedback provided, exit the loop
|
||||||
|
|
||||||
|
extracted_output = self._extract_text(current_output)
|
||||||
|
new_prompt = (
|
||||||
|
f"Below is your previous answer:\n"
|
||||||
|
f"{extracted_output}\n\n"
|
||||||
|
f"Based on the following feedback: '{feedback}', please regenerate your answer with the requested details. "
|
||||||
|
f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\n"
|
||||||
|
f"Updated answer:"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Estimate input tokens for tracking
|
||||||
|
if hasattr(self, "token_process"):
|
||||||
|
# Rough estimate based on characters
|
||||||
|
estimated_prompt_tokens = len(new_prompt) // 4 # ~4 chars per token
|
||||||
|
self.token_process.sum_prompt_tokens(estimated_prompt_tokens)
|
||||||
|
|
||||||
|
try:
|
||||||
|
new_state = self.agent_executor.invoke(
|
||||||
|
{"messages": [("user", new_prompt)]}
|
||||||
|
)
|
||||||
|
# Extract output from state based on its structure
|
||||||
|
if "structured_response" in new_state:
|
||||||
|
new_output = new_state["structured_response"]
|
||||||
|
elif "messages" in new_state and new_state["messages"]:
|
||||||
|
last_message = new_state["messages"][-1]
|
||||||
|
new_output = self._extract_text(last_message)
|
||||||
|
elif "output" in new_state:
|
||||||
|
new_output = str(new_state["output"])
|
||||||
|
else:
|
||||||
|
# Fallback to extracting text from the entire state
|
||||||
|
new_output = self._extract_text(new_state)
|
||||||
|
|
||||||
|
# Estimate completion tokens for tracking
|
||||||
|
if hasattr(self, "token_process"):
|
||||||
|
# Rough estimate based on characters
|
||||||
|
estimated_completion_tokens = (
|
||||||
|
len(new_output) // 4
|
||||||
|
) # ~4 chars per token
|
||||||
|
self.token_process.sum_completion_tokens(
|
||||||
|
estimated_completion_tokens
|
||||||
|
)
|
||||||
|
self.token_process.sum_successful_requests(1)
|
||||||
|
|
||||||
|
current_output = new_output
|
||||||
|
except Exception as e:
|
||||||
|
print("Error during re-invocation with feedback:", e)
|
||||||
|
break
|
||||||
|
|
||||||
|
return current_output
|
||||||
|
|
||||||
|
def _generate_model_description(self, model: Any) -> str:
|
||||||
|
"""
|
||||||
|
Generates a string description (schema) for the expected output.
|
||||||
|
This is a placeholder that should call the actual implementation.
|
||||||
|
"""
|
||||||
|
from crewai.utilities.converter import generate_model_description
|
||||||
|
|
||||||
|
return generate_model_description(model)
|
||||||
|
|
||||||
|
def _training_handler(self, task_prompt: str) -> str:
|
||||||
|
"""
|
||||||
|
Append training instructions from Crew data to the task prompt.
|
||||||
|
"""
|
||||||
|
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||||
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
|
|
||||||
|
data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
||||||
|
if data:
|
||||||
|
agent_id = str(self.id)
|
||||||
|
if data.get(agent_id):
|
||||||
|
human_feedbacks = [
|
||||||
|
i["human_feedback"] for i in data.get(agent_id, {}).values()
|
||||||
|
]
|
||||||
|
task_prompt += (
|
||||||
|
"\n\nYou MUST follow these instructions: \n "
|
||||||
|
+ "\n - ".join(human_feedbacks)
|
||||||
|
)
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
def _use_trained_data(self, task_prompt: str) -> str:
|
||||||
|
"""
|
||||||
|
Append pre-trained instructions from Crew data to the task prompt.
|
||||||
|
"""
|
||||||
|
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE
|
||||||
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
|
|
||||||
|
data = CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load()
|
||||||
|
if data and (trained_data_output := data.get(getattr(self, "role", "default"))):
|
||||||
|
task_prompt += (
|
||||||
|
"\n\nYou MUST follow these instructions: \n - "
|
||||||
|
+ "\n - ".join(trained_data_output["suggestions"])
|
||||||
|
)
|
||||||
|
return task_prompt
|
||||||
|
|
||||||
|
def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
|
||||||
|
"""
|
||||||
|
Creates an agent executor using LangGraph's create_react_agent if given an LLM,
|
||||||
|
or uses the provided language model directly.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"LangGraph library not found. Please run `uv add langgraph` to add LangGraph support."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Ensure raw_tools is always a list.
|
||||||
|
raw_tools: List[Any] = (
|
||||||
|
tools
|
||||||
|
if tools is not None
|
||||||
|
else (self.tools if self.tools is not None else [])
|
||||||
|
)
|
||||||
|
# Fallback: if raw_tools is still empty, try to extract them from the wrapped LangChain agent.
|
||||||
|
if not raw_tools:
|
||||||
|
if hasattr(self.langchain_agent, "agent") and hasattr(
|
||||||
|
self.langchain_agent.agent, "tools"
|
||||||
|
):
|
||||||
|
raw_tools = self.langchain_agent.agent.tools or []
|
||||||
|
else:
|
||||||
|
raw_tools = getattr(self.langchain_agent, "tools", []) or []
|
||||||
|
|
||||||
|
used_tools = []
|
||||||
|
# Use the global CrewAI Tool class (imported at the module level)
|
||||||
|
for tool in raw_tools:
|
||||||
|
# If the tool is a CrewAI Tool, convert it to a LangChain compatible tool.
|
||||||
|
if isinstance(tool, Tool):
|
||||||
|
used_tools.append(tool.to_langchain())
|
||||||
|
else:
|
||||||
|
used_tools.append(tool)
|
||||||
|
|
||||||
|
# Sanitize the agent's role for the "name" field. The allowed pattern is ^[a-zA-Z0-9_-]+$
|
||||||
|
import re
|
||||||
|
|
||||||
|
agent_role = getattr(self, "role", "agent")
|
||||||
|
sanitized_role = re.sub(r"\s+", "_", agent_role)
|
||||||
|
|
||||||
|
# Register token tracking callback
|
||||||
|
self._register_token_callback()
|
||||||
|
|
||||||
|
self.agent_executor = create_react_agent(
|
||||||
|
model=self.langchain_agent,
|
||||||
|
tools=used_tools,
|
||||||
|
debug=getattr(self, "verbose", False),
|
||||||
|
name=sanitized_role,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_tools(self, tools: List[BaseTool]) -> List[BaseTool]:
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_output_converter(
|
||||||
|
self,
|
||||||
|
llm: Any,
|
||||||
|
text: str,
|
||||||
|
model: Optional[Type] = None,
|
||||||
|
instructions: str = "",
|
||||||
|
) -> Converter:
|
||||||
|
return Converter(llm=llm, text=text, model=model, instructions=instructions)
|
||||||
|
|
||||||
|
def _show_start_logs(self, task: Task) -> None:
|
||||||
|
if self.langchain_agent is None:
|
||||||
|
raise ValueError("Agent cannot be None")
|
||||||
|
# Check if the adapter or its crew is in verbose mode.
|
||||||
|
verbose = self.verbose or (self.crew and getattr(self.crew, "verbose", False))
|
||||||
|
if verbose:
|
||||||
|
from crewai.utilities import Printer
|
||||||
|
|
||||||
|
printer = Printer()
|
||||||
|
# Use the adapter's role (inherited from BaseAgent) for logging.
|
||||||
|
printer.print(
|
||||||
|
content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{self.role}\033[00m"
|
||||||
|
)
|
||||||
|
description = getattr(task, "description", "Not Found")
|
||||||
|
printer.print(
|
||||||
|
content=f"\033[95m## Task:\033[00m \033[92m{description}\033[00m"
|
||||||
|
)
|
||||||
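Review note: the adapter above wraps an external LangChain/LangGraph runnable so it can stand in for a native CrewAI agent. A rough usage sketch; the constructor arguments beyond `langchain_agent` follow the `BaseAgent` fields (role, goal, backstory) and are assumptions about the final API rather than a confirmed example from the PR:

```python
# Sketch only: requires langchain_openai, langgraph, and valid model credentials.
from langchain_openai import ChatOpenAI

from crewai import Crew, Task
from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter

adapter = LangChainAgentAdapter(
    langchain_agent=ChatOpenAI(model="gpt-4o-mini"),
    role="Researcher",
    goal="Answer questions concisely",
    backstory="A wrapped LangChain model acting as a CrewAI agent",
    verbose=True,
)

task = Task(
    description="Summarize the benefits of token tracking in one paragraph.",
    expected_output="A short paragraph.",
    agent=adapter,
)

crew = Crew(agents=[adapter], tasks=[task])
result = crew.kickoff()
print(result)
print(adapter.token_process.get_summary())  # estimated or provider-reported usage
```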
@@ -94,7 +94,7 @@ class Crew(BaseModel):
 
     __hash__ = object.__hash__  # type: ignore
     _execution_span: Any = PrivateAttr()
-    _rpm_controller: RPMController = PrivateAttr()
+    _rpm_controller: Optional[RPMController] = PrivateAttr()
     _logger: Logger = PrivateAttr()
     _file_handler: FileHandler = PrivateAttr()
     _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
@@ -248,7 +248,6 @@ class Crew(BaseModel):
     @model_validator(mode="after")
     def set_private_attrs(self) -> "Crew":
         """Set private attributes."""
-        self._cache_handler = CacheHandler()
         self._logger = Logger(verbose=self.verbose)
         if self.output_log_file:
             self._file_handler = FileHandler(self.output_log_file)
@@ -258,6 +257,24 @@ class Crew(BaseModel):
 
         return self
 
+    @model_validator(mode="after")
+    def initialize_dependencies(self) -> "Crew":
+        # Always create a cache handler, but it will only be used if self.cache is True
+        # Create the Crew-level RPM controller if a max RPM is specified
+        if self.max_rpm is not None:
+            self._rpm_controller = RPMController(
+                max_rpm=self.max_rpm, logger=Logger(verbose=self.verbose)
+            )
+        else:
+            self._rpm_controller = None
+
+        # Now inject these external dependencies into each agent
+        for agent in self.agents:
+            agent.crew = self  # ensure the agent's crew reference is set
+            agent.configure_executor(self._cache_handler, self._rpm_controller)
+
+        return self
+
     @model_validator(mode="after")
     def create_crew_memory(self) -> "Crew":
         """Set private attributes."""
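Review note: with this validator, the RPM controller is created once at the crew level and pushed into every agent instead of each agent building its own. A small sketch of the observable effect, assuming ordinary agents and tasks:

```python
# Sketch: the validator runs automatically when the Crew model is constructed.
from crewai import Agent, Crew, Task

agent = Agent(role="Researcher", goal="Find facts", backstory="Curious and careful")
task = Task(
    description="List three facts about tokens.",
    expected_output="Three bullet points.",
    agent=agent,
)

crew = Crew(agents=[agent], tasks=[task], max_rpm=10)

# initialize_dependencies created a crew-level RPMController (max_rpm=10),
# set agent.crew = crew, and called agent.configure_executor(...) with the
# shared cache handler and controller.
assert agent.crew is crew
```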
@@ -357,10 +374,7 @@ class Crew(BaseModel):
 
         if self.agents:
             for agent in self.agents:
-                if self.cache:
-                    agent.set_cache_handler(self._cache_handler)
-                if self.max_rpm:
-                    agent.set_rpm_controller(self._rpm_controller)
+                agent.configure_executor(self._cache_handler, self._rpm_controller)
         return self
 
     @model_validator(mode="after")
@@ -627,7 +641,7 @@ class Crew(BaseModel):
         for after_callback in self.after_kickoff_callbacks:
             result = after_callback(result)
 
-        metrics += [agent._token_process.get_summary() for agent in self.agents]
+        metrics += [agent.token_process.get_summary() for agent in self.agents]
 
         self.usage_metrics = UsageMetrics()
         for metric in metrics:
@@ -1174,19 +1188,22 @@ class Crew(BaseModel):
             agent.interpolate_inputs(inputs)
 
     def _finish_execution(self, final_string_output: str) -> None:
-        if self.max_rpm:
+        if self._rpm_controller:
             self._rpm_controller.stop_rpm_counter()
 
     def calculate_usage_metrics(self) -> UsageMetrics:
         """Calculates and returns the usage metrics."""
         total_usage_metrics = UsageMetrics()
         for agent in self.agents:
-            if hasattr(agent, "_token_process"):
-                token_sum = agent._token_process.get_summary()
-                total_usage_metrics.add_usage_metrics(token_sum)
-        if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
-            token_sum = self.manager_agent._token_process.get_summary()
+            # Directly access token_process since it's now a field in BaseAgent
+            token_sum = agent.token_process.get_summary()
             total_usage_metrics.add_usage_metrics(token_sum)
+
+        if self.manager_agent:
+            # Directly access token_process since it's now a field in BaseAgent
+            token_sum = self.manager_agent.token_process.get_summary()
+            total_usage_metrics.add_usage_metrics(token_sum)
+
         self.usage_metrics = total_usage_metrics
         return total_usage_metrics
 
@@ -1,7 +1,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from inspect import signature
-from typing import Any, Callable, Type, get_args, get_origin
+from typing import Any, Callable, Optional, Type, get_args, get_origin
 
 from pydantic import (
     BaseModel,
@@ -19,11 +19,21 @@ from crewai.tools.structured_tool import CrewStructuredTool
 warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
 
 
+# Define a helper function with an explicit signature
+def default_cache_function(
+    _args: Optional[Any] = None, _result: Optional[Any] = None
+) -> bool:
+    return True
+
+
 class BaseTool(BaseModel, ABC):
     class _ArgsSchemaPlaceholder(PydanticBaseModel):
         pass
 
-    model_config = ConfigDict()
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        from_attributes=True,  # Allow conversion from ORM objects
+    )
 
     name: str
     """The unique name of the tool that clearly communicates its purpose."""
@@ -33,8 +43,10 @@ class BaseTool(BaseModel, ABC):
     """The schema for the arguments that the tool accepts."""
     description_updated: bool = False
     """Flag to check if the description has been updated."""
-    cache_function: Callable = lambda _args=None, _result=None: True
-    """Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
+    cache_function: Callable[[Optional[Any], Optional[Any]], bool] = (
+        default_cache_function
+    )
+    """Function used to determine if the tool should be cached."""
     result_as_answer: bool = False
     """Flag to check if the tool should be the final agent answer."""
 
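Review note: `cache_function` now has an explicit `(args, result) -> bool` signature, with `default_cache_function` (always cache) as the default. A short sketch of supplying a custom predicate, assuming the `@tool` decorator returns the usual crewai `Tool` instance:

```python
# Sketch: cache a tool call only when the result is a positive number.
from typing import Any, Optional

from crewai.tools import tool


@tool("Adder")
def adder(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


def cache_only_positive(_args: Optional[Any] = None, _result: Optional[Any] = None) -> bool:
    return isinstance(_result, (int, float)) and _result > 0


adder.cache_function = cache_only_positive
```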
@@ -177,74 +189,43 @@ class BaseTool(BaseModel, ABC):
 
         return origin.__name__
 
+    @property
+    def get(self) -> Callable[[str, Any], Any]:
+        # Instead of an inline lambda, we define a helper function with explicit types.
+        def _getter(key: str, default: Any = None) -> Any:
+            return getattr(self, key, default)
+
+        return _getter
+
 
 class Tool(BaseTool):
-    """The function that will be executed when the tool is called."""
+    """Tool implementation that requires a function."""
 
     func: Callable
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        from_attributes=True,
+    )
 
     def _run(self, *args: Any, **kwargs: Any) -> Any:
         return self.func(*args, **kwargs)
 
-    @classmethod
-    def from_langchain(cls, tool: Any) -> "Tool":
-        """Create a Tool instance from a CrewStructuredTool.
-
-        This method takes a CrewStructuredTool object and converts it into a
-        Tool instance. It ensures that the provided tool has a callable 'func'
-        attribute and infers the argument schema if not explicitly provided.
-
-        Args:
-            tool (Any): The CrewStructuredTool object to be converted.
-
-        Returns:
-            Tool: A new Tool instance created from the provided CrewStructuredTool.
-
-        Raises:
-            ValueError: If the provided tool does not have a callable 'func' attribute.
-        """
-        if not hasattr(tool, "func") or not callable(tool.func):
-            raise ValueError("The provided tool must have a callable 'func' attribute.")
-
-        args_schema = getattr(tool, "args_schema", None)
-
-        if args_schema is None:
-            # Infer args_schema from the function signature if not provided
-            func_signature = signature(tool.func)
-            annotations = func_signature.parameters
-            args_fields = {}
-            for name, param in annotations.items():
-                if name != "self":
-                    param_annotation = (
-                        param.annotation if param.annotation != param.empty else Any
-                    )
-                    field_info = Field(
-                        default=...,
-                        description="",
-                    )
-                    args_fields[name] = (param_annotation, field_info)
-            if args_fields:
-                args_schema = create_model(f"{tool.name}Input", **args_fields)
-            else:
-                # Create a default schema with no fields if no parameters are found
-                args_schema = create_model(
-                    f"{tool.name}Input", __base__=PydanticBaseModel
-                )
-
-        return cls(
-            name=getattr(tool, "name", "Unnamed Tool"),
-            description=getattr(tool, "description", ""),
-            func=tool.func,
-            args_schema=args_schema,
+    def to_langchain(self) -> Any:
+        """Convert to a LangChain-compatible tool."""
+        try:
+            from langchain_core.tools import Tool as LC_Tool
+        except ImportError:
+            raise ImportError("langchain_core is not installed")
+
+        # Use self._run (which is bound and calls self.func) so that the LC_Tool gets proper attributes.
+        return LC_Tool(
+            name=self.name,
+            description=self.description,
+            func=self._run,
+            args_schema=self.args_schema,
         )
 
 
-def to_langchain(
-    tools: list[BaseTool | CrewStructuredTool],
-) -> list[CrewStructuredTool]:
-    return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
-
-
 def tool(*args):
     """
     Decorator to create a tool from a function.
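Review note: the conversion direction flips here. Instead of the removed module-level `to_langchain` helper and the `from_langchain` classmethod, a CrewAI `Tool` exports itself with `to_langchain()`. A rough sketch, assuming `langchain_core` is installed:

```python
# Sketch: converting a CrewAI tool into a LangChain tool object.
from crewai.tools import tool


@tool("Echo")
def echo(text: str) -> str:
    """Return the input text unchanged."""
    return text


lc_tool = echo.to_langchain()  # langchain_core.tools.Tool wrapping echo._run
print(lc_tool.name, lc_tool.description)
# Invocation shape depends on the LangChain version; a dict matching args_schema
# is the usual form for tools with structured arguments.
print(lc_tool.invoke({"text": "hello"}))
```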
@@ -1,15 +1,52 @@
 import warnings
-from typing import Any, Dict, Optional
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
 
+from langchain_core.callbacks.base import BaseCallbackHandler
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.types.utils import Usage
 
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 
 
-class TokenCalcHandler(CustomLogger):
-    def __init__(self, token_cost_process: Optional[TokenProcess]):
-        self.token_cost_process = token_cost_process
+class AbstractTokenCounter(ABC):
+    """
+    Abstract base class for token counting callbacks.
+    Implementations should track token usage from different LLM providers.
+    """
+
+    def __init__(self, token_process: Optional[TokenProcess] = None):
+        """Initialize with a TokenProcess instance to track tokens."""
+        self.token_process = token_process
+
+    @abstractmethod
+    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
+        """Update token usage counts in the token process."""
+        pass
+
+
+class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter):
+    """
+    Token counter implementation for LiteLLM.
+    Uses LiteLLM's CustomLogger interface to track token usage.
+    """
+
+    def __init__(self, token_process: Optional[TokenProcess] = None):
+        AbstractTokenCounter.__init__(self, token_process)
+        CustomLogger.__init__(self)
+
+    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
+        """Update token usage counts in the token process."""
+        if self.token_process is None:
+            return
+
+        if prompt_tokens > 0:
+            self.token_process.sum_prompt_tokens(prompt_tokens)
+
+        if completion_tokens > 0:
+            self.token_process.sum_completion_tokens(completion_tokens)
+
+        self.token_process.sum_successful_requests(1)
 
     def log_success_event(
         self,
@@ -18,7 +55,11 @@ class TokenCalcHandler(CustomLogger):
         start_time: float,
         end_time: float,
     ) -> None:
-        if self.token_cost_process is None:
+        """
+        Process successful LLM call and extract token usage information.
+        This method is called by LiteLLM after a successful completion.
+        """
+        if self.token_process is None:
             return
 
         with warnings.catch_warnings():
@@ -26,18 +67,159 @@ class TokenCalcHandler(CustomLogger):
|
|||||||
if isinstance(response_obj, dict) and "usage" in response_obj:
|
if isinstance(response_obj, dict) and "usage" in response_obj:
|
||||||
usage: Usage = response_obj["usage"]
|
usage: Usage = response_obj["usage"]
|
||||||
if usage:
|
if usage:
|
||||||
self.token_cost_process.sum_successful_requests(1)
|
prompt_tokens = 0
|
||||||
|
completion_tokens = 0
|
||||||
|
|
||||||
if hasattr(usage, "prompt_tokens"):
|
if hasattr(usage, "prompt_tokens"):
|
||||||
self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
|
prompt_tokens = usage.prompt_tokens
|
||||||
|
elif isinstance(usage, dict) and "prompt_tokens" in usage:
|
||||||
|
prompt_tokens = usage["prompt_tokens"]
|
||||||
|
|
||||||
if hasattr(usage, "completion_tokens"):
|
if hasattr(usage, "completion_tokens"):
|
||||||
self.token_cost_process.sum_completion_tokens(
|
completion_tokens = usage.completion_tokens
|
||||||
usage.completion_tokens
|
elif isinstance(usage, dict) and "completion_tokens" in usage:
|
||||||
)
|
completion_tokens = usage["completion_tokens"]
|
||||||
|
|
||||||
|
self.update_token_usage(prompt_tokens, completion_tokens)
|
||||||
|
|
||||||
|
# Handle cached tokens if available
|
||||||
if (
|
if (
|
||||||
hasattr(usage, "prompt_tokens_details")
|
hasattr(usage, "prompt_tokens_details")
|
||||||
and usage.prompt_tokens_details
|
and usage.prompt_tokens_details
|
||||||
and usage.prompt_tokens_details.cached_tokens
|
and usage.prompt_tokens_details.cached_tokens
|
||||||
):
|
):
|
||||||
self.token_cost_process.sum_cached_prompt_tokens(
|
self.token_process.sum_cached_prompt_tokens(
|
||||||
usage.prompt_tokens_details.cached_tokens
|
usage.prompt_tokens_details.cached_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
|
||||||
|
"""
|
||||||
|
Token counter implementation for LangChain.
|
||||||
|
Implements the necessary callback methods to track token usage from LangChain responses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, token_process: Optional[TokenProcess] = None):
|
||||||
|
BaseCallbackHandler.__init__(self)
|
||||||
|
AbstractTokenCounter.__init__(self, token_process)
|
||||||
|
|
||||||
|
def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
|
||||||
|
"""Update token usage counts in the token process."""
|
||||||
|
if self.token_process is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if prompt_tokens > 0:
|
||||||
|
self.token_process.sum_prompt_tokens(prompt_tokens)
|
||||||
|
|
||||||
|
if completion_tokens > 0:
|
||||||
|
self.token_process.sum_completion_tokens(completion_tokens)
|
||||||
|
|
||||||
|
self.token_process.sum_successful_requests(1)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ignore_llm(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ignore_chain(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ignore_agent(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ignore_chat_model(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ignore_retriever(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ignore_tools(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_llm_start(
|
||||||
|
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Called when LLM starts processing."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||||
|
"""Called when LLM generates a new token."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_llm_end(self, response: Any, **kwargs: Any) -> None:
|
||||||
|
"""
|
||||||
|
Called when LLM ends processing.
|
||||||
|
Extracts token usage from LangChain response objects.
|
||||||
|
"""
|
||||||
|
if self.token_process is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle LangChain response format
|
||||||
|
if hasattr(response, "llm_output") and isinstance(response.llm_output, dict):
|
||||||
|
token_usage = response.llm_output.get("token_usage", {})
|
||||||
|
|
||||||
|
prompt_tokens = token_usage.get("prompt_tokens", 0)
|
||||||
|
completion_tokens = token_usage.get("completion_tokens", 0)
|
||||||
|
|
||||||
|
self.update_token_usage(prompt_tokens, completion_tokens)
|
||||||
|
|
||||||
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
|
"""Called when LLM errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_chain_start(
|
||||||
|
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Called when a chain starts."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
|
"""Called when a chain ends."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
|
"""Called when a chain errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_tool_start(
|
||||||
|
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Called when a tool starts."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
|
"""Called when a tool ends."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
|
"""Called when a tool errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
|
"""Called when text is generated."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_agent_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
|
"""Called when an agent starts."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_agent_end(self, output: Any, **kwargs: Any) -> None:
|
||||||
|
"""Called when an agent ends."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_agent_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
|
"""Called when an agent errors."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# For backward compatibility
|
||||||
|
class TokenCalcHandler(LiteLLMTokenCounter):
|
||||||
|
"""
|
||||||
|
Alias for LiteLLMTokenCounter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|||||||
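Review note: `LiteLLMTokenCounter` and the new `LangChainTokenCounter` both feed the same `TokenProcess`, and `TokenCalcHandler` remains as a backward-compatible alias. A minimal sketch of wiring the LangChain-side counter onto a chat model by hand (the adapter's `_register_token_callback` does this automatically); the model choice and credentials are assumptions:

```python
# Sketch: attach a LangChainTokenCounter to a LangChain chat model and read totals.
from langchain_openai import ChatOpenAI

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.token_counter_callback import LangChainTokenCounter

token_process = TokenProcess()
counter = LangChainTokenCounter(token_process)

llm = ChatOpenAI(model="gpt-4o-mini", callbacks=[counter])
llm.invoke("Say hello in five words.")

summary = token_process.get_summary()
print(summary.prompt_tokens, summary.completion_tokens, summary.successful_requests)
```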
@@ -547,6 +547,7 @@ def test_crew_with_delegating_agents():
 
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_crew_with_delegating_agents_should_not_override_task_tools():
+
     from typing import Type
 
     from pydantic import BaseModel, Field
|||||||
@@ -18,15 +18,15 @@ def test_llm_callback_replacement():
     llm1 = LLM(model="gpt-4o-mini")
     llm2 = LLM(model="gpt-4o-mini")
 
-    calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess())
-    calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess())
+    calc_handler_1 = TokenCalcHandler(token_process=TokenProcess())
+    calc_handler_2 = TokenCalcHandler(token_process=TokenProcess())
 
     result1 = llm1.call(
         messages=[{"role": "user", "content": "Hello, world!"}],
         callbacks=[calc_handler_1],
     )
     print("result1:", result1)
-    usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()
+    usage_metrics_1 = calc_handler_1.token_process.get_summary()
     print("usage_metrics_1:", usage_metrics_1)
 
     result2 = llm2.call(
||||||
@@ -35,13 +35,13 @@ def test_llm_callback_replacement():
     )
     sleep(5)
     print("result2:", result2)
-    usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
+    usage_metrics_2 = calc_handler_2.token_process.get_summary()
     print("usage_metrics_2:", usage_metrics_2)
 
     # The first handler should not have been updated
     assert usage_metrics_1.successful_requests == 1
     assert usage_metrics_2.successful_requests == 1
-    assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
+    assert usage_metrics_1 == calc_handler_1.token_process.get_summary()
 
 
 @pytest.mark.vcr(filter_headers=["authorization"])
||||||
@@ -57,14 +57,14 @@ def test_llm_call_with_string_input():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_string_input_and_callbacks():
    llm = LLM(model="gpt-4o-mini")
-   calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
+   calc_handler = TokenCalcHandler(token_process=TokenProcess())

    # Test the call method with a string input and callbacks
    result = llm.call(
        "Tell me a joke.",
        callbacks=[calc_handler],
    )
-   usage_metrics = calc_handler.token_cost_process.get_summary()
+   usage_metrics = calc_handler.token_process.get_summary()

    assert isinstance(result, str)
    assert len(result.strip()) > 0
@@ -285,6 +285,7 @@ def test_o3_mini_reasoning_effort_medium():
    assert isinstance(result, str)
    assert "Paris" in result


def test_context_window_validation():
    """Test that context window validation works correctly."""
    # Test valid window size
tests/utilities/test_token_tracking.py (new file, 189 lines)
@@ -0,0 +1,189 @@
#!/usr/bin/env python
"""
Test module for token tracking functionality in CrewAI.
This tests both direct LangChain models and LiteLLM integration.
"""

import os
from typing import Any, Dict
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.tools import Tool
from langchain_openai import ChatOpenAI

from crewai import Crew, Process, Task
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter
from crewai.utilities.token_counter_callback import (
    LangChainTokenCounter,
    LiteLLMTokenCounter,
)


def get_weather(location: str = "San Francisco"):
    """Simulates fetching current weather data for a given location."""
    # In a real implementation, you could replace this with an API call.
    return f"Current weather in {location}: Sunny, 25°C"


class TestTokenTracking:
    """Test suite for token tracking functionality."""

    @pytest.fixture
    def weather_tool(self):
        """Create a simple weather tool for testing."""
        return Tool(
            name="Weather",
            func=get_weather,
            description="Useful for fetching current weather information for a given location.",
        )

    @pytest.fixture
    def mock_openai_response(self):
        """Create a mock OpenAI response with token usage information."""
        return {
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150,
            }
        }

    def test_token_process_basic(self):
        """Test basic functionality of TokenProcess class."""
        token_process = TokenProcess()

        # Test adding prompt tokens
        token_process.sum_prompt_tokens(100)
        assert token_process.prompt_tokens == 100

        # Test adding completion tokens
        token_process.sum_completion_tokens(50)
        assert token_process.completion_tokens == 50

        # Test adding successful requests
        token_process.sum_successful_requests(1)
        assert token_process.successful_requests == 1

        # Test getting summary
        summary = token_process.get_summary()
        assert summary.prompt_tokens == 100
        assert summary.completion_tokens == 50
        assert summary.total_tokens == 150
        assert summary.successful_requests == 1

    @patch("litellm.completion")
    def test_litellm_token_counter(self, mock_completion):
        """Test LiteLLMTokenCounter with a mock response."""
        # Setup
        token_process = TokenProcess()
        counter = LiteLLMTokenCounter(token_process)

        # Mock the response
        mock_completion.return_value = {
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
            }
        }

        # Simulate a successful LLM call
        counter.log_success_event(
            kwargs={},
            response_obj=mock_completion.return_value,
            start_time=0,
            end_time=1,
        )

        # Verify token counts were updated
        assert token_process.prompt_tokens == 100
        assert token_process.completion_tokens == 50
        assert token_process.successful_requests == 1

    def test_langchain_token_counter(self):
        """Test LangChainTokenCounter with a mock response."""
        # Setup
        token_process = TokenProcess()
        counter = LangChainTokenCounter(token_process)

        # Create a mock LangChain response
        mock_response = MagicMock()
        mock_response.llm_output = {
            "token_usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
            }
        }

        # Simulate a successful LLM call
        counter.on_llm_end(mock_response)

        # Verify token counts were updated
        assert token_process.prompt_tokens == 100
        assert token_process.completion_tokens == 50
        assert token_process.successful_requests == 1

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY"),
        reason="OPENAI_API_KEY environment variable not set",
    )
    def test_langchain_agent_adapter_token_tracking(self, weather_tool):
        """
        Integration test for token tracking with LangChainAgentAdapter.
        This test requires an OpenAI API key.
        """
        # Skip if LangGraph is not installed
        try:
            from langgraph.prebuilt import ToolNode
        except ImportError:
            pytest.skip("LangGraph is not installed. Install it with: uv add langgraph")

        # Initialize a ChatOpenAI model
        llm = ChatOpenAI(model="gpt-4o")

        # Create a LangChainAgentAdapter with the direct LLM
        agent = LangChainAgentAdapter(
            langchain_agent=llm,
            tools=[weather_tool],
            role="Weather Agent",
            goal="Provide current weather information for the requested location.",
            backstory="An expert weather provider that fetches current weather information using simulated data.",
            verbose=True,
        )

        # Create a weather task for the agent
        task = Task(
            description="Fetch the current weather for San Francisco.",
            expected_output="A weather report showing current conditions in San Francisco.",
            agent=agent,
        )

        # Create a crew with the single agent and task
        crew = Crew(
            agents=[agent],
            tasks=[task],
            verbose=True,
            process=Process.sequential,
        )

        # Execute the crew
        result = crew.kickoff()

        # Verify token usage was tracked
        assert result.token_usage is not None
        assert result.token_usage.total_tokens > 0
        assert result.token_usage.prompt_tokens > 0
        assert result.token_usage.completion_tokens > 0
        assert result.token_usage.successful_requests > 0

        # Also verify token usage directly from the agent
        usage = agent.token_process.get_summary()
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0
        assert usage.total_tokens > 0
        assert usage.successful_requests > 0


if __name__ == "__main__":
    pytest.main(["-xvs", __file__])
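The __main__ guard at the end of the new module mirrors invoking pytest directly; assuming the path shown in the file header above, the equivalent command line would be:

pytest -xvs tests/utilities/test_token_tracking.py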