more wip.

This commit is contained in:
Brandon Hancock
2024-06-26 20:18:23 -07:00
parent be0a4c2fe5
commit 764234c426
3 changed files with 32 additions and 36 deletions

View File

@@ -1,6 +1,6 @@
import os import os
import uuid import uuid
from copy import deepcopy from copy import copy, deepcopy
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from langchain.agents.agent import RunnableAgent from langchain.agents.agent import RunnableAgent
@@ -10,14 +10,14 @@ from langchain_core.agents import AgentAction
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from pydantic import ( from pydantic import (
UUID4, UUID4,
BaseModel, BaseModel,
ConfigDict, ConfigDict,
Field, Field,
InstanceOf, InstanceOf,
PrivateAttr, PrivateAttr,
field_validator, field_validator,
model_validator, model_validator,
) )
from pydantic_core import PydanticCustomError from pydantic_core import PydanticCustomError
@@ -379,16 +379,17 @@ class Agent(BaseModel):
"tools", "tools",
"tools_handler", "tools_handler",
"cache_handler", "cache_handler",
"llm", # TODO: THIS GET'S THINGS WORKING AGAIN.`` "llm",
} }
print("LLM IN COPY", self.llm.model_name) existing_llm = copy(self.llm)
# TODO: EXPAND ON WHY THIS IS NEEDED
# RESET LLM CALLBACKS
existing_llm.callbacks = []
copied_data = self.model_dump(exclude=exclude) copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None} copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = Agent(**copied_data) copied_agent = Agent(**copied_data, llm=existing_llm, tools=self.tools)
print("COPIED AGENT LLM", copied_agent.llm.model_name)
copied_agent.tools = deepcopy(self.tools)
return copied_agent return copied_agent

View File

@@ -303,16 +303,6 @@ class Crew(BaseModel):
# TODO: I would expect we would want to merge the usage metrics from each crew execution # TODO: I would expect we would want to merge the usage metrics from each crew execution
results.append(output) results.append(output)
print("CREW USAGE METRICS:", crew.usage_metrics)
print(
"ORIGINAL AGENT USAGE METRICS",
[agent._token_process.get_summary() for agent in self.agents],
)
print(
"COPIED AGENT USAGE METRICS",
[agent._token_process.get_summary() for agent in crew.agents],
)
return results return results
async def kickoff_async( async def kickoff_async(
@@ -321,13 +311,17 @@ class Crew(BaseModel):
"""Asynchronous kickoff method to start the crew execution.""" """Asynchronous kickoff method to start the crew execution."""
return await asyncio.to_thread(self.kickoff, inputs) return await asyncio.to_thread(self.kickoff, inputs)
# TODO: IF THERE ARE MULTIPLE INPUTS, THE USAGE METRICS FOR FIRST ONE COMES BACK AS 0.
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[Any]: async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[Any]:
async def run_crew(input_data): crew_copies = [self.copy() for _ in inputs]
crew = self.copy()
async def run_crew(crew, input_data):
return await crew.kickoff_async(inputs=input_data) return await crew.kickoff_async(inputs=input_data)
tasks = [asyncio.create_task(run_crew(input_data)) for input_data in inputs] tasks = [
asyncio.create_task(run_crew(crew_copies[i], inputs[i]))
for i in range(len(inputs))
]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@@ -379,16 +373,15 @@ class Crew(BaseModel):
if self.output_log_file: if self.output_log_file:
self._file_handler.log(agent=role, task=task_output, status="completed") self._file_handler.log(agent=role, task=task_output, status="completed")
# Update token usage for the current task for agent in self.agents:
current_token_usage = task.agent._token_process.get_summary() agent_token_usage = agent._token_process.get_summary()
for key in total_token_usage: for key in total_token_usage:
total_token_usage[key] += current_token_usage.get(key, 0) total_token_usage[key] += agent_token_usage.get(key, 0)
self._finish_execution(task_output) self._finish_execution(task_output)
# type: ignore # Item "None" of "Agent | None" has no attribute "_token_process") # type: ignore # Item "None" of "Agent | None" has no attribute "_token_process")
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str") # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
# TODO: TEST AND FIX
return self._format_output(task_output, total_token_usage) return self._format_output(task_output, total_token_usage)
def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]: def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]:
@@ -432,10 +425,10 @@ class Crew(BaseModel):
agent=manager.role, task=task_output, status="completed" agent=manager.role, task=task_output, status="completed"
) )
# TODO: GET TOKENS USAGE CALCULATED INCLUDING MANAGER
self._finish_execution(task_output) self._finish_execution(task_output)
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str") # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
manager_token_usage = manager._token_process.get_summary() manager_token_usage = manager._token_process.get_summary()
# TODO: TEST AND FIX
return ( return (
self._format_output(task_output, manager_token_usage), self._format_output(task_output, manager_token_usage),
manager_token_usage, manager_token_usage,
@@ -497,6 +490,8 @@ class Crew(BaseModel):
Formats the output of the crew execution. Formats the output of the crew execution.
If full_output is True, then returned data type will be a dictionary else returned outputs are string If full_output is True, then returned data type will be a dictionary else returned outputs are string
""" """
print("token_usage passed to _format_output", token_usage)
if self.full_output: if self.full_output:
return { # type: ignore # Incompatible return value type (got "dict[str, Sequence[str | TaskOutput | None]]", expected "str") return { # type: ignore # Incompatible return value type (got "dict[str, Sequence[str | TaskOutput | None]]", expected "str")
"final_output": output, "final_output": output,

View File

@@ -35,18 +35,18 @@ class TokenProcess:
class TokenCalcHandler(BaseCallbackHandler): class TokenCalcHandler(BaseCallbackHandler):
id = uuid.uuid4() # TODO: REMOVE THIS id = uuid.uuid4() # TODO: REMOVE THIS
model: str = "" model_name: str = ""
token_cost_process: TokenProcess token_cost_process: TokenProcess
def __init__(self, model, token_cost_process): def __init__(self, model_name, token_cost_process):
self.model = model self.model_name = model_name
self.token_cost_process = token_cost_process self.token_cost_process = token_cost_process
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None: ) -> None:
try: try:
encoding = tiktoken.encoding_for_model(self.model) encoding = tiktoken.encoding_for_model(self.model_name)
except KeyError: except KeyError:
encoding = tiktoken.get_encoding("cl100k_base") encoding = tiktoken.get_encoding("cl100k_base")