mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-09 08:08:32 +00:00)
more wip.
@@ -1,6 +1,6 @@
 import os
 import uuid
-from copy import deepcopy
+from copy import copy, deepcopy
 from typing import Any, Dict, List, Optional, Tuple

 from langchain.agents.agent import RunnableAgent
@@ -10,14 +10,14 @@ from langchain_core.agents import AgentAction
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_openai import ChatOpenAI
 from pydantic import (
-    UUID4,
-    BaseModel,
-    ConfigDict,
-    Field,
-    InstanceOf,
-    PrivateAttr,
-    field_validator,
-    model_validator,
+    UUID4,
+    BaseModel,
+    ConfigDict,
+    Field,
+    InstanceOf,
+    PrivateAttr,
+    field_validator,
+    model_validator,
 )
 from pydantic_core import PydanticCustomError
@@ -379,16 +379,17 @@ class Agent(BaseModel):
             "tools",
             "tools_handler",
             "cache_handler",
-            "llm",  # TODO: THIS GET'S THINGS WORKING AGAIN.
+            "llm",
         }

-        print("LLM IN COPY", self.llm.model_name)
+        existing_llm = copy(self.llm)
+        # TODO: EXPAND ON WHY THIS IS NEEDED
+        # RESET LLM CALLBACKS
+        existing_llm.callbacks = []
         copied_data = self.model_dump(exclude=exclude)
         copied_data = {k: v for k, v in copied_data.items() if v is not None}

-        copied_agent = Agent(**copied_data)
-        print("COPIED AGENT LLM", copied_agent.llm.model_name)
-        copied_agent.tools = deepcopy(self.tools)
+        copied_agent = Agent(**copied_data, llm=existing_llm, tools=self.tools)

         return copied_agent

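Note: the switch from deepcopy-then-reassign to a shallow copy(self.llm) with cleared callbacks is the crux of this hunk. A minimal runnable sketch of the aliasing problem it works around (FakeLLM is a hypothetical stand-in, not crewAI's API):

    # A shallow copy shares the callbacks list, so handlers registered on the
    # original would also fire (and count tokens) for the copy unless the
    # copy's callbacks attribute is rebound to a fresh list.
    from copy import copy

    class FakeLLM:
        def __init__(self) -> None:
            self.callbacks: list = []

    original = FakeLLM()
    original.callbacks.append("token_counter_for_original")

    clone = copy(original)
    assert clone.callbacks is original.callbacks      # shared after shallow copy

    clone.callbacks = []                              # mirrors existing_llm.callbacks = []
    assert clone.callbacks is not original.callbacks
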
@@ -303,16 +303,6 @@ class Crew(BaseModel):
             # TODO: I would expect we would want to merge the usage metrics from each crew execution
             results.append(output)

-            print("CREW USAGE METRICS:", crew.usage_metrics)
-            print(
-                "ORIGINAL AGENT USAGE METRICS",
-                [agent._token_process.get_summary() for agent in self.agents],
-            )
-            print(
-                "COPIED AGENT USAGE METRICS",
-                [agent._token_process.get_summary() for agent in crew.agents],
-            )
-
         return results

     async def kickoff_async(
@@ -321,13 +311,17 @@ class Crew(BaseModel):
         """Asynchronous kickoff method to start the crew execution."""
         return await asyncio.to_thread(self.kickoff, inputs)

+    # TODO: IF THERE ARE MULTIPLE INPUTS, THE USAGE METRICS FOR FIRST ONE COMES BACK AS 0.
     async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[Any]:
-        async def run_crew(input_data):
-            crew = self.copy()
+        crew_copies = [self.copy() for _ in inputs]
+
+        async def run_crew(crew, input_data):
             return await crew.kickoff_async(inputs=input_data)

-        tasks = [asyncio.create_task(run_crew(input_data)) for input_data in inputs]
+        tasks = [
+            asyncio.create_task(run_crew(crew_copies[i], inputs[i]))
+            for i in range(len(inputs))
+        ]

         results = await asyncio.gather(*tasks)
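Note: the rewrite creates one crew copy per input up front and passes each copy into its task explicitly, instead of copying inside the coroutine, so every concurrent run owns independent state (including its token counters). A runnable sketch of the same fan-out pattern (Worker and run_for_each are hypothetical stand-ins):

    import asyncio

    class Worker:
        """Hypothetical stand-in for a Crew: stateful, so tasks must not share it."""

        def copy(self) -> "Worker":
            return Worker()

        async def run(self, item: int) -> int:
            return item * 2

    async def run_for_each(template: Worker, items: list) -> list:
        # One independent copy per input, created up front (mirrors crew_copies).
        copies = [template.copy() for _ in items]
        tasks = [
            asyncio.create_task(copies[i].run(items[i]))
            for i in range(len(items))
        ]
        return await asyncio.gather(*tasks)

    print(asyncio.run(run_for_each(Worker(), [1, 2, 3])))  # [2, 4, 6]
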
@@ -379,16 +373,15 @@ class Crew(BaseModel):
             if self.output_log_file:
                 self._file_handler.log(agent=role, task=task_output, status="completed")

-            # Update token usage for the current task
-            current_token_usage = task.agent._token_process.get_summary()  # type: ignore # Item "None" of "Agent | None" has no attribute "_token_process"
-            for key in total_token_usage:
-                total_token_usage[key] += current_token_usage.get(key, 0)
+        for agent in self.agents:
+            agent_token_usage = agent._token_process.get_summary()
+            for key in total_token_usage:
+                total_token_usage[key] += agent_token_usage.get(key, 0)

         self._finish_execution(task_output)

-        return self._format_output(task_output, total_token_usage)  # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
+        # TODO: TEST AND FIX
+        return self._format_output(task_output, total_token_usage)

     def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]:
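Note: this hunk moves token aggregation from per-task to per-agent. Assuming get_summary() returns an agent's cumulative totals, adding it once per task would double-count any agent that ran several tasks; summing once per agent at the end avoids that. A sketch with made-up numbers:

    # Hypothetical summaries; keys mirror the shape a token summary might have.
    agent_summaries = [
        {"total_tokens": 120, "prompt_tokens": 90, "completion_tokens": 30},
        {"total_tokens": 80, "prompt_tokens": 60, "completion_tokens": 20},
    ]

    total_token_usage = {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
    for summary in agent_summaries:               # once per agent, not per task
        for key in total_token_usage:
            total_token_usage[key] += summary.get(key, 0)

    print(total_token_usage)  # {'total_tokens': 200, 'prompt_tokens': 150, 'completion_tokens': 50}
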
@@ -432,10 +425,10 @@
                 agent=manager.role, task=task_output, status="completed"
             )

-        # TODO: GET TOKENS USAGE CALCULATED INCLUDING MANAGER
         self._finish_execution(task_output)
-        # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
+        manager_token_usage = manager._token_process.get_summary()
+        # TODO: TEST AND FIX
         return (
             self._format_output(task_output, manager_token_usage),
             manager_token_usage,
         )
@@ -497,6 +490,8 @@ class Crew(BaseModel):
         Formats the output of the crew execution.
         If full_output is True, then returned data type will be a dictionary else returned outputs are string
         """
+
+        print("token_usage passed to _format_output", token_usage)
         if self.full_output:
             return {  # type: ignore # Incompatible return value type (got "dict[str, Sequence[str | TaskOutput | None]]", expected "str")
                 "final_output": output,
@@ -35,18 +35,18 @@ class TokenProcess:

 class TokenCalcHandler(BaseCallbackHandler):
     id = uuid.uuid4()  # TODO: REMOVE THIS
-    model: str = ""
+    model_name: str = ""
     token_cost_process: TokenProcess

-    def __init__(self, model, token_cost_process):
-        self.model = model
+    def __init__(self, model_name, token_cost_process):
+        self.model_name = model_name
         self.token_cost_process = token_cost_process

     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         try:
-            encoding = tiktoken.encoding_for_model(self.model)
+            encoding = tiktoken.encoding_for_model(self.model_name)
         except KeyError:
             encoding = tiktoken.get_encoding("cl100k_base")

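Note: the try/except around encoding_for_model matches tiktoken's behavior: it raises KeyError for model names it does not recognize, so the handler falls back to a fixed encoding. A self-contained sketch of the same fallback (encoding_for is a hypothetical helper name):

    import tiktoken

    def encoding_for(model_name: str) -> tiktoken.Encoding:
        try:
            return tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown model name: fall back to a fixed encoding, as above.
            return tiktoken.get_encoding("cl100k_base")

    enc = encoding_for("definitely-not-a-real-model")
    print(len(enc.encode("hello world")))  # token count under the fallback encoding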