more wip.

This commit is contained in:
Brandon Hancock
2024-06-26 20:18:23 -07:00
parent be0a4c2fe5
commit 764234c426
3 changed files with 32 additions and 36 deletions

View File

@@ -1,6 +1,6 @@
import os
import uuid
from copy import deepcopy
from copy import copy, deepcopy
from typing import Any, Dict, List, Optional, Tuple
from langchain.agents.agent import RunnableAgent
@@ -10,14 +10,14 @@ from langchain_core.agents import AgentAction
from langchain_core.callbacks import BaseCallbackHandler
from langchain_openai import ChatOpenAI
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
field_validator,
model_validator,
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
field_validator,
model_validator,
)
from pydantic_core import PydanticCustomError
@@ -379,16 +379,17 @@ class Agent(BaseModel):
"tools",
"tools_handler",
"cache_handler",
"llm", # TODO: THIS GET'S THINGS WORKING AGAIN.``
"llm",
}
print("LLM IN COPY", self.llm.model_name)
existing_llm = copy(self.llm)
# TODO: EXPAND ON WHY THIS IS NEEDED
# RESET LLM CALLBACKS
existing_llm.callbacks = []
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = Agent(**copied_data)
print("COPIED AGENT LLM", copied_agent.llm.model_name)
copied_agent.tools = deepcopy(self.tools)
copied_agent = Agent(**copied_data, llm=existing_llm, tools=self.tools)
return copied_agent

View File

@@ -303,16 +303,6 @@ class Crew(BaseModel):
# TODO: I would expect we would want to merge the usage metrics from each crew execution
results.append(output)
print("CREW USAGE METRICS:", crew.usage_metrics)
print(
"ORIGINAL AGENT USAGE METRICS",
[agent._token_process.get_summary() for agent in self.agents],
)
print(
"COPIED AGENT USAGE METRICS",
[agent._token_process.get_summary() for agent in crew.agents],
)
return results
async def kickoff_async(
@@ -321,13 +311,17 @@ class Crew(BaseModel):
"""Asynchronous kickoff method to start the crew execution."""
return await asyncio.to_thread(self.kickoff, inputs)
# TODO: IF THERE ARE MULTIPLE INPUTS, THE USAGE METRICS FOR FIRST ONE COMES BACK AS 0.
async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[Any]:
async def run_crew(input_data):
crew = self.copy()
crew_copies = [self.copy() for _ in inputs]
async def run_crew(crew, input_data):
return await crew.kickoff_async(inputs=input_data)
tasks = [asyncio.create_task(run_crew(input_data)) for input_data in inputs]
tasks = [
asyncio.create_task(run_crew(crew_copies[i], inputs[i]))
for i in range(len(inputs))
]
results = await asyncio.gather(*tasks)
@@ -379,16 +373,15 @@ class Crew(BaseModel):
if self.output_log_file:
self._file_handler.log(agent=role, task=task_output, status="completed")
# Update token usage for the current task
current_token_usage = task.agent._token_process.get_summary()
for agent in self.agents:
agent_token_usage = agent._token_process.get_summary()
for key in total_token_usage:
total_token_usage[key] += current_token_usage.get(key, 0)
total_token_usage[key] += agent_token_usage.get(key, 0)
self._finish_execution(task_output)
# type: ignore # Item "None" of "Agent | None" has no attribute "_token_process")
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
# TODO: TEST AND FIX
return self._format_output(task_output, total_token_usage)
def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]:
@@ -432,10 +425,10 @@ class Crew(BaseModel):
agent=manager.role, task=task_output, status="completed"
)
# TODO: GET TOKENS USAGE CALCULATED INCLUDING MANAGER
self._finish_execution(task_output)
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
manager_token_usage = manager._token_process.get_summary()
# TODO: TEST AND FIX
return (
self._format_output(task_output, manager_token_usage),
manager_token_usage,
@@ -497,6 +490,8 @@ class Crew(BaseModel):
Formats the output of the crew execution.
If full_output is True, then returned data type will be a dictionary else returned outputs are string
"""
print("token_usage passed to _format_output", token_usage)
if self.full_output:
return { # type: ignore # Incompatible return value type (got "dict[str, Sequence[str | TaskOutput | None]]", expected "str")
"final_output": output,

View File

@@ -35,18 +35,18 @@ class TokenProcess:
class TokenCalcHandler(BaseCallbackHandler):
id = uuid.uuid4() # TODO: REMOVE THIS
model: str = ""
model_name: str = ""
token_cost_process: TokenProcess
def __init__(self, model, token_cost_process):
self.model = model
def __init__(self, model_name, token_cost_process):
self.model_name = model_name
self.token_cost_process = token_cost_process
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
try:
encoding = tiktoken.encoding_for_model(self.model)
encoding = tiktoken.encoding_for_model(self.model_name)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")