Mirror of https://github.com/crewAIInc/crewAI.git, synced 2025-12-16 12:28:30 +00:00

Compare commits: devin/1764... → brandon/br... (52 commits)

Commit SHA1s:
6efee89399, 75d8e086a4, ef48cbe971, 88d8079dcd, 33ef612cd5, 6c81acac00, 2181166a62,
3f17789152, ed0b8e1563, 4915420ed2, 271faa917b, e87eb57edc, e6e4bb15d7, 92a3349d64,
940f0647a9, 5db512bef6, 48d2b8c320, f6b0b492a4, 4b63b29787, 860efc3b42, 70ab4ad003,
6fb25a1af7, ca9277ae4c, e58e544304, 84e0a9e686, 6e0e9b30fe, f6393fd088, 64804682fc,
cc7669ab39, 377b64ac81, 470254c3e2, 497190f823, 241adb8ed0, e60c6e66a4, afd01e3c0c,
81b8ae0abd, ae19437473, f8c74b4fbb, 134e7ab241, 74e63621a5, 8953af6133, 2b438baad4,
185556b7e3, 1e23d37a14, df21f01441, b957fc1a18, fd0e1bdd1a, 265b37316b, ff32880a54,
a38483e1b4, 7910dc9337, 796e50aba8
@@ -1,7 +1,7 @@
import re
import shutil
import subprocess
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
from typing import Any, Dict, List, Literal, Optional, Sequence, Union, cast

from pydantic import Field, InstanceOf, PrivateAttr, model_validator

@@ -170,27 +170,19 @@ class Agent(BaseAgent):
            Output of the agent
        """
        if self.tools_handler:
            self.tools_handler.last_used_tool = {}  # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")
            self.tools_handler.last_used_tool = {}  # type: ignore # Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "ToolCalling")

        task_prompt = task.prompt()

        # If the task requires output in JSON or Pydantic format,
        # append specific instructions to the task prompt to ensure
        # that the final answer does not include any code block markers
        if task.output_json or task.output_pydantic:
            # Generate the schema based on the output format
            if task.output_json:
                # schema = json.dumps(task.output_json, indent=2)
                schema = generate_model_description(task.output_json)
                task_prompt += "\n" + self.i18n.slice(
                    "formatted_task_instructions"
                ).format(output_format=schema)

            elif task.output_pydantic:
                schema = generate_model_description(task.output_pydantic)
                task_prompt += "\n" + self.i18n.slice(
                    "formatted_task_instructions"
                ).format(output_format=schema)
            # Choose the output format, preferring output_json if available
            output_format = (
                task.output_json if task.output_json else task.output_pydantic
            )
            schema = generate_model_description(cast(type, output_format))
            task_prompt += f"\n{self.i18n.slice('formatted_task_instructions').format(output_format=schema)}"

        if context:
            task_prompt = self.i18n.slice("task_with_context").format(
@@ -276,9 +268,6 @@ class Agent(BaseAgent):
                raise e
            result = self.execute_task(task, context, tools)

        if self.max_rpm and self._rpm_controller:
            self._rpm_controller.stop_rpm_counter()

        # If there was any tool in self.tools_results that had result_as_answer
        # set to True, return the results of the last tool that had
        # result_as_answer set to True
@@ -338,7 +327,7 @@ class Agent(BaseAgent):
            request_within_rpm_limit=(
                self._rpm_controller.check_or_wait if self._rpm_controller else None
            ),
            callbacks=[TokenCalcHandler(self._token_process)],
            callbacks=[TokenCalcHandler(self.token_process)],
        )

    def get_delegation_tools(self, agents: List[BaseAgent]):

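The second hunk above collapses the duplicated output_json / output_pydantic branches into a single selection that prefers output_json. A minimal, self-contained sketch of that selection order (the Report model and pick_output_format helper are illustrative only; generate_model_description is the crewAI helper the diff actually calls):

```python
from typing import Optional, Type, cast

from pydantic import BaseModel


class Report(BaseModel):  # illustrative output model, not part of the diff
    title: str
    summary: str


def pick_output_format(
    output_json: Optional[Type[BaseModel]],
    output_pydantic: Optional[Type[BaseModel]],
) -> Type[BaseModel]:
    # Mirror the preference order used above: output_json wins when both are set.
    return cast(Type[BaseModel], output_json if output_json else output_pydantic)


# Both set -> output_json wins; only output_pydantic set -> it is used instead.
assert pick_output_format(Report, None) is Report
assert pick_output_format(None, Report) is Report
```
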
@@ -73,20 +73,27 @@ class BaseAgent(ABC, BaseModel):
            Increment formatting errors.
        copy() -> "BaseAgent":
            Create a copy of the agent.
        set_rpm_controller(rpm_controller: RPMController) -> None:
        set_rpm_controller(rpm_controller: Optional[RPMController] = None) -> None:
            Set the rpm controller for the agent.
        set_private_attrs() -> "BaseAgent":
            Set private attributes.
        configure_executor(cache_handler: CacheHandler, rpm_controller: RPMController) -> None:
            Configure the agent's executor with both cache and RPM handling.
    """

    __hash__ = object.__hash__  # type: ignore

    model_config = {
        "arbitrary_types_allowed": True,
    }

    _logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
    _rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
    _request_within_rpm_limit: Any = PrivateAttr(default=None)
    _original_role: Optional[str] = PrivateAttr(default=None)
    _original_goal: Optional[str] = PrivateAttr(default=None)
    _original_backstory: Optional[str] = PrivateAttr(default=None)
    _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
    token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
    id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
    formatting_errors: int = Field(
        default=0, description="Number of formatting errors."
@@ -196,8 +203,6 @@ class BaseAgent(ABC, BaseModel):
            self._rpm_controller = RPMController(
                max_rpm=self.max_rpm, logger=self._logger
            )
        if not self._token_process:
            self._token_process = TokenProcess()

        return self

@@ -217,8 +222,7 @@ class BaseAgent(ABC, BaseModel):
            self._rpm_controller = RPMController(
                max_rpm=self.max_rpm, logger=self._logger
            )
        if not self._token_process:
            self._token_process = TokenProcess()

        return self

    @property
@@ -266,7 +270,7 @@ class BaseAgent(ABC, BaseModel):
            "_logger",
            "_rpm_controller",
            "_request_within_rpm_limit",
            "_token_process",
            "token_process",
            "agent_executor",
            "tools",
            "tools_handler",
@@ -337,20 +341,49 @@ class BaseAgent(ABC, BaseModel):
        if self.cache:
            self.cache_handler = cache_handler
            self.tools_handler.cache = cache_handler
        self.create_agent_executor()
        # Only create the executor if it hasn't been created yet.
        if self.agent_executor is None:
            self.create_agent_executor()

    def increment_formatting_errors(self) -> None:
        self.formatting_errors += 1

    def set_rpm_controller(self, rpm_controller: RPMController) -> None:
        """Set the rpm controller for the agent.

        Args:
            rpm_controller: An instance of the RPMController class.
    def set_rpm_controller(
        self, rpm_controller: Optional[RPMController] = None
    ) -> None:
        """
        if not self._rpm_controller:
            self._rpm_controller = rpm_controller
            self.create_agent_executor()
        Set the RPM controller for the agent. If no rpm_controller is provided, then:
        - use self.max_rpm if set, or
        - if self.crew exists and has max_rpm, use that.
        """
        if self._rpm_controller is None:
            if rpm_controller is not None:
                self._rpm_controller = rpm_controller
            elif self.max_rpm:
                self._rpm_controller = RPMController(
                    max_rpm=self.max_rpm, logger=self._logger
                )
            elif self.crew and getattr(self.crew, "max_rpm", None):
                self._rpm_controller = RPMController(
                    max_rpm=self.crew.max_rpm, logger=self._logger
                )
            # else: no rpm limit provided – leave the controller None
        if self.agent_executor is None:
            self.create_agent_executor()

    def configure_executor(
        self, cache_handler: CacheHandler, rpm_controller: Optional[RPMController]
    ) -> None:
        """Configure the agent's executor with both cache and RPM handling.

        This method delegates to set_cache_handler and set_rpm_controller, applying the configuration
        only if the respective flags or values are set.
        """
        if self.cache:
            self.set_cache_handler(cache_handler)
        # Use the injected RPM controller rather than auto-creating one
        if rpm_controller:
            self.set_rpm_controller(rpm_controller)

    def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
        pass

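The rewritten set_rpm_controller resolves the controller from three sources in order: an injected controller, the agent's own max_rpm, then the crew's max_rpm. A standalone sketch of that precedence, using a stand-in controller class rather than crewAI's RPMController:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRPMController:  # stand-in for crewai's RPMController
    max_rpm: int


def resolve_rpm_controller(
    injected: Optional[FakeRPMController],
    agent_max_rpm: Optional[int],
    crew_max_rpm: Optional[int],
) -> Optional[FakeRPMController]:
    # Same precedence as BaseAgent.set_rpm_controller above.
    if injected is not None:
        return injected
    if agent_max_rpm:
        return FakeRPMController(max_rpm=agent_max_rpm)
    if crew_max_rpm:
        return FakeRPMController(max_rpm=crew_max_rpm)
    return None  # no RPM limit anywhere: leave the controller unset


assert resolve_rpm_controller(None, 10, 60).max_rpm == 10   # agent limit wins over crew
assert resolve_rpm_controller(None, None, 60).max_rpm == 60 # crew limit as fallback
assert resolve_rpm_controller(None, None, None) is None
```
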
@@ -88,7 +88,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
            tool.name: tool for tool in self.tools
        }
        self.stop = stop_words
        self.llm.stop = list(set(self.llm.stop + self.stop))
        self.llm.stop = list(set((self.llm.stop or []) + self.stop))

    def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        if "system" in self.prompt:

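The one-line change to the stop-word handling guards against llm.stop being None; a quick sketch of the difference:

```python
stop_words = ["\nObservation:"]

# Old form: raises TypeError when the LLM has no stop words configured (stop is None).
existing = None
try:
    merged = list(set(existing + stop_words))
except TypeError:
    merged = None

# New form: falls back to an empty list, so the merge always works.
merged = list(set((existing or []) + stop_words))
assert merged == ["\nObservation:"]
```
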
src/crewai/agents/langchain_agent_adapter.py (new file, 468 lines)
@@ -0,0 +1,468 @@
from typing import Any, List, Optional, Type, Union, cast

from pydantic import Field, field_validator

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.task import Task
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities.converter import Converter, generate_model_description
from crewai.utilities.token_counter_callback import (
    LangChainTokenCounter,
    LiteLLMTokenCounter,
)


class LangChainAgentAdapter(BaseAgent):
    """
    Adapter class to wrap a LangChain agent and make it compatible with CrewAI's BaseAgent interface.

    Note:
        - This adapter does not require LangChain as a dependency.
        - It wraps an external LangChain agent (passed as any type) and delegates calls
          such as execute_task() to the LangChain agent's invoke() method.
        - Extended logic is added to build prompts, incorporate memory, knowledge, training hints,
          and now a human feedback loop similar to what is done in CrewAgentExecutor.
    """

    langchain_agent: Any = Field(
        ...,
        description="The wrapped LangChain runnable agent instance. It is expected to have an 'invoke' method.",
    )
    tools: Optional[List[Union[BaseTool, Any]]] = Field(
        default_factory=list,
        description="Tools at the agent's disposal. Accepts both CrewAI BaseTool instances and other tools.",
    )
    function_calling_llm: Optional[Any] = Field(
        default=None, description="Optional function calling LLM."
    )
    step_callback: Optional[Any] = Field(
        default=None,
        description="Callback executed after each step of agent execution.",
    )
    allow_code_execution: Optional[bool] = Field(
        default=False, description="Enable code execution for the agent."
    )
    multimodal: bool = Field(
        default=False, description="Whether the agent is multimodal."
    )
    i18n: Any = None
    crew: Any = None
    knowledge: Any = None
    token_process: TokenProcess = Field(default_factory=TokenProcess, exclude=True)
    token_callback: Optional[Any] = None

    class Config:
        arbitrary_types_allowed = True

    @field_validator("tools", mode="before")
    def convert_tools(cls, value):
        """Ensure tools are valid CrewAI BaseTool instances."""
        if not value:
            return value
        new_tools = []
        for tool in value:
            # If tool is already a CrewAI BaseTool instance, keep it as is.
            if isinstance(tool, BaseTool):
                new_tools.append(tool)
            else:
                new_tools.append(Tool.from_langchain(tool))
        return new_tools

    def _extract_text(self, message: Any) -> str:
        """
        Helper to extract plain text from a message object.
        This checks if the message is a dict with a "content" key, or has a "content" attribute,
        or if it's a tuple from LangGraph's message format.
        """
        # Handle LangGraph message tuple format (role, content)
        if isinstance(message, tuple) and len(message) == 2:
            return str(message[1])

        # Handle dictionary with content key
        elif isinstance(message, dict):
            if "content" in message:
                return message["content"]
            # Handle LangGraph message format with additional metadata
            elif "messages" in message and message["messages"]:
                last_message = message["messages"][-1]
                if isinstance(last_message, tuple) and len(last_message) == 2:
                    return str(last_message[1])
                return self._extract_text(last_message)

        # Handle object with content attribute
        elif hasattr(message, "content") and isinstance(
            getattr(message, "content"), str
        ):
            return getattr(message, "content")

        # Handle string directly
        elif isinstance(message, str):
            return message

        # Default fallback
        return str(message)

    def _register_token_callback(self):
        """
        Register the appropriate token counter callback with the language model.
        This method handles different types of models (LiteLLM, LangChain, direct LLMs)
        and different callback structures.
        """
        # Skip if we already have a token callback registered
        if self.token_callback is not None:
            return

        # Skip if we don't have a token_process attribute
        if not hasattr(self, "token_process"):
            return

        # Determine if we're using LiteLLM or LangChain based on the agent type
        if hasattr(self.langchain_agent, "client") and hasattr(
            self.langchain_agent.client, "callbacks"
        ):
            # This is likely a LiteLLM-based agent
            self.token_callback = LiteLLMTokenCounter(self.token_process)

            # Add our callback to the LLM directly
            if isinstance(self.langchain_agent.client.callbacks, list):
                if self.token_callback not in self.langchain_agent.client.callbacks:
                    self.langchain_agent.client.callbacks.append(self.token_callback)
            else:
                self.langchain_agent.client.callbacks = [self.token_callback]
        else:
            # This is likely a LangChain-based agent
            self.token_callback = LangChainTokenCounter(self.token_process)

            # Add callback to the LangChain model
            if hasattr(self.langchain_agent, "callbacks"):
                if self.langchain_agent.callbacks is None:
                    self.langchain_agent.callbacks = [self.token_callback]
                elif isinstance(self.langchain_agent.callbacks, list):
                    self.langchain_agent.callbacks.append(self.token_callback)
            # For direct LLM models
            elif hasattr(self.langchain_agent, "llm") and hasattr(
                self.langchain_agent.llm, "callbacks"
            ):
                if self.langchain_agent.llm.callbacks is None:
                    self.langchain_agent.llm.callbacks = [self.token_callback]
                elif isinstance(self.langchain_agent.llm.callbacks, list):
                    self.langchain_agent.llm.callbacks.append(self.token_callback)
            # Direct LLM case
            elif not hasattr(self.langchain_agent, "agent"):
                # This might be a direct LLM, not an agent
                if (
                    not hasattr(self.langchain_agent, "callbacks")
                    or self.langchain_agent.callbacks is None
                ):
                    self.langchain_agent.callbacks = [self.token_callback]
                elif isinstance(self.langchain_agent.callbacks, list):
                    self.langchain_agent.callbacks.append(self.token_callback)

    def execute_task(
        self,
        task: Task,
        context: Optional[str] = None,
        tools: Optional[List[BaseTool]] = None,
    ) -> str:
        """
        Execute a task by building the full task prompt (with memory, knowledge, tool instructions,
        and training hints) then delegating execution to the wrapped LangChain agent.
        If the task requires human input, a feedback loop is run that mimics the CrewAgentExecutor.
        """
        task_prompt = task.prompt()

        if task.output_json or task.output_pydantic:
            # Choose the output format, preferring output_json if available
            output_format = (
                task.output_json if task.output_json else task.output_pydantic
            )
            schema = generate_model_description(cast(type, output_format))
            instruction = self.i18n.slice("formatted_task_instructions").format(
                output_format=schema
            )
            task_prompt += f"\n{instruction}"

        if context:
            task_prompt = self.i18n.slice("task_with_context").format(
                task=task_prompt, context=context
            )

        if self.crew and self.crew.memory:
            from crewai.memory.contextual.contextual_memory import ContextualMemory

            contextual_memory = ContextualMemory(
                self.crew.memory_config,
                self.crew._short_term_memory,
                self.crew._long_term_memory,
                self.crew._entity_memory,
                self.crew._user_memory,
            )
            memory = contextual_memory.build_context_for_task(task, context)
            if memory.strip():
                task_prompt += self.i18n.slice("memory").format(memory=memory)

        if self.knowledge:
            agent_knowledge_snippets = self.knowledge.query([task.prompt()])
            if agent_knowledge_snippets:
                from crewai.knowledge.utils.knowledge_utils import (
                    extract_knowledge_context,
                )

                agent_knowledge_context = extract_knowledge_context(
                    agent_knowledge_snippets
                )
                if agent_knowledge_context:
                    task_prompt += agent_knowledge_context

        if self.crew:
            knowledge_snippets = self.crew.query_knowledge([task.prompt()])
            if knowledge_snippets:
                from crewai.knowledge.utils.knowledge_utils import (
                    extract_knowledge_context,
                )

                crew_knowledge_context = extract_knowledge_context(knowledge_snippets)
                if crew_knowledge_context:
                    task_prompt += crew_knowledge_context

        tools = tools or self.tools or []
        self.create_agent_executor(tools=tools)

        self._show_start_logs(task)

        if self.crew and getattr(self.crew, "_train", False):
            task_prompt = self._training_handler(task_prompt=task_prompt)
        else:
            task_prompt = self._use_trained_data(task_prompt=task_prompt)

        # Register token tracking callback
        self._register_token_callback()

        init_state = {"messages": [("user", task_prompt)]}

        # Estimate input tokens for tracking
        if hasattr(self, "token_process"):
            # Rough estimate based on characters (better than word count)
            estimated_prompt_tokens = len(task_prompt) // 4  # ~4 chars per token
            self.token_process.sum_prompt_tokens(estimated_prompt_tokens)

        state = self.agent_executor.invoke(init_state)

        # Extract output from state based on its structure
        if "structured_response" in state:
            current_output = state["structured_response"]
        elif "messages" in state and state["messages"]:
            last_message = state["messages"][-1]
            current_output = self._extract_text(last_message)
        elif "output" in state:
            current_output = str(state["output"])
        else:
            # Fallback to extracting text from the entire state
            current_output = self._extract_text(state)

        # Estimate completion tokens for tracking if we don't have actual counts
        if hasattr(self, "token_process"):
            # Rough estimate based on characters
            estimated_completion_tokens = len(current_output) // 4  # ~4 chars per token
            self.token_process.sum_completion_tokens(estimated_completion_tokens)
            self.token_process.sum_successful_requests(1)

        if task.human_input:
            current_output = self._handle_human_feedback(current_output)

        return current_output

    def _handle_human_feedback(self, current_output: str) -> str:
        """
        Implements a feedback loop that prompts the user for feedback and then instructs
        the underlying LangChain agent to regenerate its answer with the requested changes.
        Only the inner content of the output is displayed to the user.
        """
        while True:
            print("\nAgent output:")
            # Print only the inner text extracted from current_output.
            print(self._extract_text(current_output))

            feedback = input("\nEnter your feedback (or press Enter to accept): ")
            if not feedback.strip():
                break  # No feedback provided, exit the loop

            extracted_output = self._extract_text(current_output)
            new_prompt = (
                f"Below is your previous answer:\n"
                f"{extracted_output}\n\n"
                f"Based on the following feedback: '{feedback}', please regenerate your answer with the requested details. "
                f"Specifically, display 10 bullet points in each section. Provide the complete updated answer below.\n\n"
                f"Updated answer:"
            )

            # Estimate input tokens for tracking
            if hasattr(self, "token_process"):
                # Rough estimate based on characters
                estimated_prompt_tokens = len(new_prompt) // 4  # ~4 chars per token
                self.token_process.sum_prompt_tokens(estimated_prompt_tokens)

            try:
                new_state = self.agent_executor.invoke(
                    {"messages": [("user", new_prompt)]}
                )
                # Extract output from state based on its structure
                if "structured_response" in new_state:
                    new_output = new_state["structured_response"]
                elif "messages" in new_state and new_state["messages"]:
                    last_message = new_state["messages"][-1]
                    new_output = self._extract_text(last_message)
                elif "output" in new_state:
                    new_output = str(new_state["output"])
                else:
                    # Fallback to extracting text from the entire state
                    new_output = self._extract_text(new_state)

                # Estimate completion tokens for tracking
                if hasattr(self, "token_process"):
                    # Rough estimate based on characters
                    estimated_completion_tokens = (
                        len(new_output) // 4
                    )  # ~4 chars per token
                    self.token_process.sum_completion_tokens(
                        estimated_completion_tokens
                    )
                    self.token_process.sum_successful_requests(1)

                current_output = new_output
            except Exception as e:
                print("Error during re-invocation with feedback:", e)
                break

        return current_output

    def _generate_model_description(self, model: Any) -> str:
        """
        Generates a string description (schema) for the expected output.
        This is a placeholder that should call the actual implementation.
        """
        from crewai.utilities.converter import generate_model_description

        return generate_model_description(model)

    def _training_handler(self, task_prompt: str) -> str:
        """
        Append training instructions from Crew data to the task prompt.
        """
        from crewai.utilities.constants import TRAINING_DATA_FILE
        from crewai.utilities.training_handler import CrewTrainingHandler

        data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
        if data:
            agent_id = str(self.id)
            if data.get(agent_id):
                human_feedbacks = [
                    i["human_feedback"] for i in data.get(agent_id, {}).values()
                ]
                task_prompt += (
                    "\n\nYou MUST follow these instructions: \n "
                    + "\n - ".join(human_feedbacks)
                )
        return task_prompt

    def _use_trained_data(self, task_prompt: str) -> str:
        """
        Append pre-trained instructions from Crew data to the task prompt.
        """
        from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE
        from crewai.utilities.training_handler import CrewTrainingHandler

        data = CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load()
        if data and (trained_data_output := data.get(getattr(self, "role", "default"))):
            task_prompt += (
                "\n\nYou MUST follow these instructions: \n - "
                + "\n - ".join(trained_data_output["suggestions"])
            )
        return task_prompt

    def create_agent_executor(self, tools: Optional[List[BaseTool]] = None) -> None:
        """
        Creates an agent executor using LangGraph's create_react_agent if given an LLM,
        or uses the provided language model directly.
        """
        try:
            from langgraph.prebuilt import create_react_agent
        except ImportError as e:
            raise ImportError(
                "LangGraph library not found. Please run `uv add langgraph` to add LangGraph support."
            ) from e

        # Ensure raw_tools is always a list.
        raw_tools: List[Any] = (
            tools
            if tools is not None
            else (self.tools if self.tools is not None else [])
        )
        # Fallback: if raw_tools is still empty, try to extract them from the wrapped LangChain agent.
        if not raw_tools:
            if hasattr(self.langchain_agent, "agent") and hasattr(
                self.langchain_agent.agent, "tools"
            ):
                raw_tools = self.langchain_agent.agent.tools or []
            else:
                raw_tools = getattr(self.langchain_agent, "tools", []) or []

        used_tools = []
        # Use the global CrewAI Tool class (imported at the module level)
        for tool in raw_tools:
            # If the tool is a CrewAI Tool, convert it to a LangChain compatible tool.
            if isinstance(tool, Tool):
                used_tools.append(tool.to_langchain())
            else:
                used_tools.append(tool)

        # Sanitize the agent's role for the "name" field. The allowed pattern is ^[a-zA-Z0-9_-]+$
        import re

        agent_role = getattr(self, "role", "agent")
        sanitized_role = re.sub(r"\s+", "_", agent_role)

        # Register token tracking callback
        self._register_token_callback()

        self.agent_executor = create_react_agent(
            model=self.langchain_agent,
            tools=used_tools,
            debug=getattr(self, "verbose", False),
            name=sanitized_role,
        )

    def _parse_tools(self, tools: List[BaseTool]) -> List[BaseTool]:
        return tools

    def get_delegation_tools(self, agents: List["BaseAgent"]) -> List[BaseTool]:
        return []

    def get_output_converter(
        self,
        llm: Any,
        text: str,
        model: Optional[Type] = None,
        instructions: str = "",
    ) -> Converter:
        return Converter(llm=llm, text=text, model=model, instructions=instructions)

    def _show_start_logs(self, task: Task) -> None:
        if self.langchain_agent is None:
            raise ValueError("Agent cannot be None")
        # Check if the adapter or its crew is in verbose mode.
        verbose = self.verbose or (self.crew and getattr(self.crew, "verbose", False))
        if verbose:
            from crewai.utilities import Printer

            printer = Printer()
            # Use the adapter's role (inherited from BaseAgent) for logging.
            printer.print(
                content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{self.role}\033[00m"
            )
            description = getattr(task, "description", "Not Found")
            printer.print(
                content=f"\033[95m## Task:\033[00m \033[92m{description}\033[00m"
            )

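A hedged usage sketch of the new adapter outside of a Crew, assuming langchain-openai and langgraph are installed and an OpenAI key is configured; the integration test added at the end of this diff exercises the same adapter through a full Crew, so this is only a minimal illustration:

```python
from langchain_openai import ChatOpenAI

from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter
from crewai.task import Task

# Wrap a plain LangChain chat model so it behaves like a CrewAI agent.
adapter = LangChainAgentAdapter(
    langchain_agent=ChatOpenAI(model="gpt-4o"),
    role="Research Assistant",
    goal="Answer questions concisely.",
    backstory="A LangChain model wrapped so it can act as a CrewAI agent.",
    verbose=True,
)

task = Task(
    description="Summarize what the LangChainAgentAdapter does.",
    expected_output="A short summary.",
    agent=adapter,
)

# execute_task builds the prompt, creates a LangGraph ReAct executor and returns text;
# estimated token usage accumulates on adapter.token_process.
answer = adapter.execute_task(task)
print(answer)
print(adapter.token_process.get_summary())
```
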
@@ -94,7 +94,7 @@ class Crew(BaseModel):

    __hash__ = object.__hash__  # type: ignore
    _execution_span: Any = PrivateAttr()
    _rpm_controller: RPMController = PrivateAttr()
    _rpm_controller: Optional[RPMController] = PrivateAttr()
    _logger: Logger = PrivateAttr()
    _file_handler: FileHandler = PrivateAttr()
    _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
@@ -248,7 +248,6 @@ class Crew(BaseModel):
    @model_validator(mode="after")
    def set_private_attrs(self) -> "Crew":
        """Set private attributes."""
        self._cache_handler = CacheHandler()
        self._logger = Logger(verbose=self.verbose)
        if self.output_log_file:
            self._file_handler = FileHandler(self.output_log_file)
@@ -258,6 +257,24 @@ class Crew(BaseModel):

        return self

    @model_validator(mode="after")
    def initialize_dependencies(self) -> "Crew":
        # Always create a cache handler, but it will only be used if self.cache is True
        # Create the Crew-level RPM controller if a max RPM is specified
        if self.max_rpm is not None:
            self._rpm_controller = RPMController(
                max_rpm=self.max_rpm, logger=Logger(verbose=self.verbose)
            )
        else:
            self._rpm_controller = None

        # Now inject these external dependencies into each agent
        for agent in self.agents:
            agent.crew = self  # ensure the agent's crew reference is set
            agent.configure_executor(self._cache_handler, self._rpm_controller)

        return self

    @model_validator(mode="after")
    def create_crew_memory(self) -> "Crew":
        """Set private attributes."""
@@ -357,10 +374,7 @@ class Crew(BaseModel):

        if self.agents:
            for agent in self.agents:
                if self.cache:
                    agent.set_cache_handler(self._cache_handler)
                if self.max_rpm:
                    agent.set_rpm_controller(self._rpm_controller)
                agent.configure_executor(self._cache_handler, self._rpm_controller)
        return self

    @model_validator(mode="after")
@@ -627,7 +641,7 @@ class Crew(BaseModel):
        for after_callback in self.after_kickoff_callbacks:
            result = after_callback(result)

        metrics += [agent._token_process.get_summary() for agent in self.agents]
        metrics += [agent.token_process.get_summary() for agent in self.agents]

        self.usage_metrics = UsageMetrics()
        for metric in metrics:
@@ -1174,19 +1188,22 @@ class Crew(BaseModel):
            agent.interpolate_inputs(inputs)

    def _finish_execution(self, final_string_output: str) -> None:
        if self.max_rpm:
        if self._rpm_controller:
            self._rpm_controller.stop_rpm_counter()

    def calculate_usage_metrics(self) -> UsageMetrics:
        """Calculates and returns the usage metrics."""
        total_usage_metrics = UsageMetrics()
        for agent in self.agents:
            if hasattr(agent, "_token_process"):
                token_sum = agent._token_process.get_summary()
                total_usage_metrics.add_usage_metrics(token_sum)
        if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
            token_sum = self.manager_agent._token_process.get_summary()
            # Directly access token_process since it's now a field in BaseAgent
            token_sum = agent.token_process.get_summary()
            total_usage_metrics.add_usage_metrics(token_sum)

        if self.manager_agent:
            # Directly access token_process since it's now a field in BaseAgent
            token_sum = self.manager_agent.token_process.get_summary()
            total_usage_metrics.add_usage_metrics(token_sum)

        self.usage_metrics = total_usage_metrics
        return total_usage_metrics

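With token_process now a public field on every agent (see the BaseAgent hunk earlier), calculate_usage_metrics can simply sum each agent's summary. A small sketch of that aggregation; the UsageMetrics import path is assumed here and the numbers are made up:

```python
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.types.usage_metrics import UsageMetrics  # assumed import path

# Two token processes standing in for two agents' accumulated usage.
tp1, tp2 = TokenProcess(), TokenProcess()
tp1.sum_prompt_tokens(120)
tp1.sum_completion_tokens(30)
tp1.sum_successful_requests(1)
tp2.sum_prompt_tokens(80)
tp2.sum_completion_tokens(20)
tp2.sum_successful_requests(1)

# Same shape as the loop in calculate_usage_metrics above.
total = UsageMetrics()
for tp in (tp1, tp2):
    total.add_usage_metrics(tp.get_summary())

print(total.total_tokens)  # 250
```
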
@@ -1,7 +1,7 @@
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Type, get_args, get_origin
from typing import Any, Callable, Optional, Type, get_args, get_origin

from pydantic import (
    BaseModel,
@@ -19,11 +19,21 @@ from crewai.tools.structured_tool import CrewStructuredTool
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)


# Define a helper function with an explicit signature
def default_cache_function(
    _args: Optional[Any] = None, _result: Optional[Any] = None
) -> bool:
    return True


class BaseTool(BaseModel, ABC):
    class _ArgsSchemaPlaceholder(PydanticBaseModel):
        pass

    model_config = ConfigDict()
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        from_attributes=True,  # Allow conversion from ORM objects
    )

    name: str
    """The unique name of the tool that clearly communicates its purpose."""
@@ -33,8 +43,10 @@ class BaseTool(BaseModel, ABC):
    """The schema for the arguments that the tool accepts."""
    description_updated: bool = False
    """Flag to check if the description has been updated."""
    cache_function: Callable = lambda _args=None, _result=None: True
    """Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
    cache_function: Callable[[Optional[Any], Optional[Any]], bool] = (
        default_cache_function
    )
    """Function used to determine if the tool should be cached."""
    result_as_answer: bool = False
    """Flag to check if the tool should be the final agent answer."""

@@ -177,74 +189,43 @@ class BaseTool(BaseModel, ABC):

        return origin.__name__

    @property
    def get(self) -> Callable[[str, Any], Any]:
        # Instead of an inline lambda, we define a helper function with explicit types.
        def _getter(key: str, default: Any = None) -> Any:
            return getattr(self, key, default)

        return _getter


class Tool(BaseTool):
    """The function that will be executed when the tool is called."""
    """Tool implementation that requires a function."""

    func: Callable
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        from_attributes=True,
    )

    def _run(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)

    @classmethod
    def from_langchain(cls, tool: Any) -> "Tool":
        """Create a Tool instance from a CrewStructuredTool.
    def to_langchain(self) -> Any:
        """Convert to a LangChain-compatible tool."""
        try:
            from langchain_core.tools import Tool as LC_Tool
        except ImportError:
            raise ImportError("langchain_core is not installed")

        This method takes a CrewStructuredTool object and converts it into a
        Tool instance. It ensures that the provided tool has a callable 'func'
        attribute and infers the argument schema if not explicitly provided.

        Args:
            tool (Any): The CrewStructuredTool object to be converted.

        Returns:
            Tool: A new Tool instance created from the provided CrewStructuredTool.

        Raises:
            ValueError: If the provided tool does not have a callable 'func' attribute.
        """
        if not hasattr(tool, "func") or not callable(tool.func):
            raise ValueError("The provided tool must have a callable 'func' attribute.")

        args_schema = getattr(tool, "args_schema", None)

        if args_schema is None:
            # Infer args_schema from the function signature if not provided
            func_signature = signature(tool.func)
            annotations = func_signature.parameters
            args_fields = {}
            for name, param in annotations.items():
                if name != "self":
                    param_annotation = (
                        param.annotation if param.annotation != param.empty else Any
                    )
                    field_info = Field(
                        default=...,
                        description="",
                    )
                    args_fields[name] = (param_annotation, field_info)
            if args_fields:
                args_schema = create_model(f"{tool.name}Input", **args_fields)
            else:
                # Create a default schema with no fields if no parameters are found
                args_schema = create_model(
                    f"{tool.name}Input", __base__=PydanticBaseModel
                )

        return cls(
            name=getattr(tool, "name", "Unnamed Tool"),
            description=getattr(tool, "description", ""),
            func=tool.func,
            args_schema=args_schema,
        # Use self._run (which is bound and calls self.func) so that the LC_Tool gets proper attributes.
        return LC_Tool(
            name=self.name,
            description=self.description,
            func=self._run,
            args_schema=self.args_schema,
        )


def to_langchain(
    tools: list[BaseTool | CrewStructuredTool],
) -> list[CrewStructuredTool]:
    return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]


def tool(*args):
    """
    Decorator to create a tool from a function.

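The typed default_cache_function replaces the bare lambda default for cache_function. A sketch of a custom predicate matching the new Callable[[Optional[Any], Optional[Any]], bool] signature; the my_tool assignment is only an assumption of how the field is typically used:

```python
from typing import Any, Optional


def cache_only_non_empty(
    _args: Optional[Any] = None, _result: Optional[Any] = None
) -> bool:
    # Cache the tool call only when it actually produced a non-empty result.
    return bool(_result)


# It can then be assigned to any BaseTool instance's cache_function field, e.g.:
#   my_tool.cache_function = cache_only_non_empty
```
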
@@ -1,15 +1,52 @@
import warnings
from typing import Any, Dict, Optional
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from langchain_core.callbacks.base import BaseCallbackHandler
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess


class TokenCalcHandler(CustomLogger):
    def __init__(self, token_cost_process: Optional[TokenProcess]):
        self.token_cost_process = token_cost_process
class AbstractTokenCounter(ABC):
    """
    Abstract base class for token counting callbacks.
    Implementations should track token usage from different LLM providers.
    """

    def __init__(self, token_process: Optional[TokenProcess] = None):
        """Initialize with a TokenProcess instance to track tokens."""
        self.token_process = token_process

    @abstractmethod
    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Update token usage counts in the token process."""
        pass


class LiteLLMTokenCounter(CustomLogger, AbstractTokenCounter):
    """
    Token counter implementation for LiteLLM.
    Uses LiteLLM's CustomLogger interface to track token usage.
    """

    def __init__(self, token_process: Optional[TokenProcess] = None):
        AbstractTokenCounter.__init__(self, token_process)
        CustomLogger.__init__(self)

    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Update token usage counts in the token process."""
        if self.token_process is None:
            return

        if prompt_tokens > 0:
            self.token_process.sum_prompt_tokens(prompt_tokens)

        if completion_tokens > 0:
            self.token_process.sum_completion_tokens(completion_tokens)

        self.token_process.sum_successful_requests(1)

    def log_success_event(
        self,
@@ -18,7 +55,11 @@ class TokenCalcHandler(CustomLogger):
        start_time: float,
        end_time: float,
    ) -> None:
        if self.token_cost_process is None:
        """
        Process successful LLM call and extract token usage information.
        This method is called by LiteLLM after a successful completion.
        """
        if self.token_process is None:
            return

        with warnings.catch_warnings():
@@ -26,18 +67,159 @@ class TokenCalcHandler(CustomLogger):
            if isinstance(response_obj, dict) and "usage" in response_obj:
                usage: Usage = response_obj["usage"]
                if usage:
                    self.token_cost_process.sum_successful_requests(1)
                    prompt_tokens = 0
                    completion_tokens = 0

                    if hasattr(usage, "prompt_tokens"):
                        self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
                        prompt_tokens = usage.prompt_tokens
                    elif isinstance(usage, dict) and "prompt_tokens" in usage:
                        prompt_tokens = usage["prompt_tokens"]

                    if hasattr(usage, "completion_tokens"):
                        self.token_cost_process.sum_completion_tokens(
                            usage.completion_tokens
                        )
                        completion_tokens = usage.completion_tokens
                    elif isinstance(usage, dict) and "completion_tokens" in usage:
                        completion_tokens = usage["completion_tokens"]

                    self.update_token_usage(prompt_tokens, completion_tokens)

                    # Handle cached tokens if available
                    if (
                        hasattr(usage, "prompt_tokens_details")
                        and usage.prompt_tokens_details
                        and usage.prompt_tokens_details.cached_tokens
                    ):
                        self.token_cost_process.sum_cached_prompt_tokens(
                        self.token_process.sum_cached_prompt_tokens(
                            usage.prompt_tokens_details.cached_tokens
                        )


class LangChainTokenCounter(BaseCallbackHandler, AbstractTokenCounter):
    """
    Token counter implementation for LangChain.
    Implements the necessary callback methods to track token usage from LangChain responses.
    """

    def __init__(self, token_process: Optional[TokenProcess] = None):
        BaseCallbackHandler.__init__(self)
        AbstractTokenCounter.__init__(self, token_process)

    def update_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Update token usage counts in the token process."""
        if self.token_process is None:
            return

        if prompt_tokens > 0:
            self.token_process.sum_prompt_tokens(prompt_tokens)

        if completion_tokens > 0:
            self.token_process.sum_completion_tokens(completion_tokens)

        self.token_process.sum_successful_requests(1)

    @property
    def ignore_llm(self) -> bool:
        return False

    @property
    def ignore_chain(self) -> bool:
        return True

    @property
    def ignore_agent(self) -> bool:
        return False

    @property
    def ignore_chat_model(self) -> bool:
        return False

    @property
    def ignore_retriever(self) -> bool:
        return True

    @property
    def ignore_tools(self) -> bool:
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Called when LLM starts processing."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Called when LLM generates a new token."""
        pass

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """
        Called when LLM ends processing.
        Extracts token usage from LangChain response objects.
        """
        if self.token_process is None:
            return

        # Handle LangChain response format
        if hasattr(response, "llm_output") and isinstance(response.llm_output, dict):
            token_usage = response.llm_output.get("token_usage", {})

            prompt_tokens = token_usage.get("prompt_tokens", 0)
            completion_tokens = token_usage.get("completion_tokens", 0)

            self.update_token_usage(prompt_tokens, completion_tokens)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when LLM errors."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Called when a chain starts."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Called when a chain ends."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when a chain errors."""
        pass

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Called when a tool starts."""
        pass

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Called when a tool ends."""
        pass

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when a tool errors."""
        pass

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Called when text is generated."""
        pass

    def on_agent_start(self, serialized: Dict[str, Any], **kwargs: Any) -> None:
        """Called when an agent starts."""
        pass

    def on_agent_end(self, output: Any, **kwargs: Any) -> None:
        """Called when an agent ends."""
        pass

    def on_agent_error(self, error: BaseException, **kwargs: Any) -> None:
        """Called when an agent errors."""
        pass


# For backward compatibility
class TokenCalcHandler(LiteLLMTokenCounter):
    """
    Alias for LiteLLMTokenCounter.
    """

    pass

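Both concrete counters funnel their provider-specific payloads through the shared update_token_usage method. A sketch of driving LangChainTokenCounter by hand with a fake response object (the new test module later in this diff does the same thing with MagicMock):

```python
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.token_counter_callback import LangChainTokenCounter


class FakeLLMResult:  # minimal stand-in for a LangChain LLMResult
    llm_output = {"token_usage": {"prompt_tokens": 42, "completion_tokens": 7}}


process = TokenProcess()
counter = LangChainTokenCounter(process)
counter.on_llm_end(FakeLLMResult())  # extracts usage and records one successful request

summary = process.get_summary()
assert summary.prompt_tokens == 42
assert summary.completion_tokens == 7
assert summary.successful_requests == 1
```
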
@@ -547,6 +547,7 @@ def test_crew_with_delegating_agents():

@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_with_delegating_agents_should_not_override_task_tools():

    from typing import Type

    from pydantic import BaseModel, Field

@@ -18,15 +18,15 @@ def test_llm_callback_replacement():
    llm1 = LLM(model="gpt-4o-mini")
    llm2 = LLM(model="gpt-4o-mini")

    calc_handler_1 = TokenCalcHandler(token_cost_process=TokenProcess())
    calc_handler_2 = TokenCalcHandler(token_cost_process=TokenProcess())
    calc_handler_1 = TokenCalcHandler(token_process=TokenProcess())
    calc_handler_2 = TokenCalcHandler(token_process=TokenProcess())

    result1 = llm1.call(
        messages=[{"role": "user", "content": "Hello, world!"}],
        callbacks=[calc_handler_1],
    )
    print("result1:", result1)
    usage_metrics_1 = calc_handler_1.token_cost_process.get_summary()
    usage_metrics_1 = calc_handler_1.token_process.get_summary()
    print("usage_metrics_1:", usage_metrics_1)

    result2 = llm2.call(
@@ -35,13 +35,13 @@ def test_llm_callback_replacement():
    )
    sleep(5)
    print("result2:", result2)
    usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
    usage_metrics_2 = calc_handler_2.token_process.get_summary()
    print("usage_metrics_2:", usage_metrics_2)

    # The first handler should not have been updated
    assert usage_metrics_1.successful_requests == 1
    assert usage_metrics_2.successful_requests == 1
    assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
    assert usage_metrics_1 == calc_handler_1.token_process.get_summary()


@pytest.mark.vcr(filter_headers=["authorization"])
@@ -57,14 +57,14 @@ def test_llm_call_with_string_input():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_call_with_string_input_and_callbacks():
    llm = LLM(model="gpt-4o-mini")
    calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
    calc_handler = TokenCalcHandler(token_process=TokenProcess())

    # Test the call method with a string input and callbacks
    result = llm.call(
        "Tell me a joke.",
        callbacks=[calc_handler],
    )
    usage_metrics = calc_handler.token_cost_process.get_summary()
    usage_metrics = calc_handler.token_process.get_summary()

    assert isinstance(result, str)
    assert len(result.strip()) > 0
@@ -285,6 +285,7 @@ def test_o3_mini_reasoning_effort_medium():
    assert isinstance(result, str)
    assert "Paris" in result


def test_context_window_validation():
    """Test that context window validation works correctly."""
    # Test valid window size

tests/utilities/test_token_tracking.py (new file, 189 lines)
@@ -0,0 +1,189 @@
#!/usr/bin/env python
"""
Test module for token tracking functionality in CrewAI.
This tests both direct LangChain models and LiteLLM integration.
"""

import os
from typing import Any, Dict
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.tools import Tool
from langchain_openai import ChatOpenAI

from crewai import Crew, Process, Task
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.langchain_agent_adapter import LangChainAgentAdapter
from crewai.utilities.token_counter_callback import (
    LangChainTokenCounter,
    LiteLLMTokenCounter,
)


def get_weather(location: str = "San Francisco"):
    """Simulates fetching current weather data for a given location."""
    # In a real implementation, you could replace this with an API call.
    return f"Current weather in {location}: Sunny, 25°C"


class TestTokenTracking:
    """Test suite for token tracking functionality."""

    @pytest.fixture
    def weather_tool(self):
        """Create a simple weather tool for testing."""
        return Tool(
            name="Weather",
            func=get_weather,
            description="Useful for fetching current weather information for a given location.",
        )

    @pytest.fixture
    def mock_openai_response(self):
        """Create a mock OpenAI response with token usage information."""
        return {
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150,
            }
        }

    def test_token_process_basic(self):
        """Test basic functionality of TokenProcess class."""
        token_process = TokenProcess()

        # Test adding prompt tokens
        token_process.sum_prompt_tokens(100)
        assert token_process.prompt_tokens == 100

        # Test adding completion tokens
        token_process.sum_completion_tokens(50)
        assert token_process.completion_tokens == 50

        # Test adding successful requests
        token_process.sum_successful_requests(1)
        assert token_process.successful_requests == 1

        # Test getting summary
        summary = token_process.get_summary()
        assert summary.prompt_tokens == 100
        assert summary.completion_tokens == 50
        assert summary.total_tokens == 150
        assert summary.successful_requests == 1

    @patch("litellm.completion")
    def test_litellm_token_counter(self, mock_completion):
        """Test LiteLLMTokenCounter with a mock response."""
        # Setup
        token_process = TokenProcess()
        counter = LiteLLMTokenCounter(token_process)

        # Mock the response
        mock_completion.return_value = {
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
            }
        }

        # Simulate a successful LLM call
        counter.log_success_event(
            kwargs={},
            response_obj=mock_completion.return_value,
            start_time=0,
            end_time=1,
        )

        # Verify token counts were updated
        assert token_process.prompt_tokens == 100
        assert token_process.completion_tokens == 50
        assert token_process.successful_requests == 1

    def test_langchain_token_counter(self):
        """Test LangChainTokenCounter with a mock response."""
        # Setup
        token_process = TokenProcess()
        counter = LangChainTokenCounter(token_process)

        # Create a mock LangChain response
        mock_response = MagicMock()
        mock_response.llm_output = {
            "token_usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
            }
        }

        # Simulate a successful LLM call
        counter.on_llm_end(mock_response)

        # Verify token counts were updated
        assert token_process.prompt_tokens == 100
        assert token_process.completion_tokens == 50
        assert token_process.successful_requests == 1

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY"),
        reason="OPENAI_API_KEY environment variable not set",
    )
    def test_langchain_agent_adapter_token_tracking(self, weather_tool):
        """
        Integration test for token tracking with LangChainAgentAdapter.
        This test requires an OpenAI API key.
        """
        # Skip if LangGraph is not installed
        try:
            from langgraph.prebuilt import ToolNode
        except ImportError:
            pytest.skip("LangGraph is not installed. Install it with: uv add langgraph")

        # Initialize a ChatOpenAI model
        llm = ChatOpenAI(model="gpt-4o")

        # Create a LangChainAgentAdapter with the direct LLM
        agent = LangChainAgentAdapter(
            langchain_agent=llm,
            tools=[weather_tool],
            role="Weather Agent",
            goal="Provide current weather information for the requested location.",
            backstory="An expert weather provider that fetches current weather information using simulated data.",
            verbose=True,
        )

        # Create a weather task for the agent
        task = Task(
            description="Fetch the current weather for San Francisco.",
            expected_output="A weather report showing current conditions in San Francisco.",
            agent=agent,
        )

        # Create a crew with the single agent and task
        crew = Crew(
            agents=[agent],
            tasks=[task],
            verbose=True,
            process=Process.sequential,
        )

        # Execute the crew
        result = crew.kickoff()

        # Verify token usage was tracked
        assert result.token_usage is not None
        assert result.token_usage.total_tokens > 0
        assert result.token_usage.prompt_tokens > 0
        assert result.token_usage.completion_tokens > 0
        assert result.token_usage.successful_requests > 0

        # Also verify token usage directly from the agent
        usage = agent.token_process.get_summary()
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0
        assert usage.total_tokens > 0
        assert usage.successful_requests > 0


if __name__ == "__main__":
    pytest.main(["-xvs", __file__])