more wip.
@@ -1,6 +1,6 @@
 import os
 import uuid
-from copy import deepcopy
+from copy import copy, deepcopy
 from typing import Any, Dict, List, Optional, Tuple
 
 from langchain.agents.agent import RunnableAgent
@@ -10,14 +10,14 @@ from langchain_core.agents import AgentAction
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_openai import ChatOpenAI
 from pydantic import (
     UUID4,
     BaseModel,
     ConfigDict,
     Field,
     InstanceOf,
     PrivateAttr,
     field_validator,
     model_validator,
 )
 from pydantic_core import PydanticCustomError
 
@@ -379,16 +379,17 @@ class Agent(BaseModel):
             "tools",
             "tools_handler",
             "cache_handler",
-            "llm", # TODO: THIS GET'S THINGS WORKING AGAIN.``
+            "llm",
         }
 
-        print("LLM IN COPY", self.llm.model_name)
+        existing_llm = copy(self.llm)
+        # TODO: EXPAND ON WHY THIS IS NEEDED
+        # RESET LLM CALLBACKS
+        existing_llm.callbacks = []
         copied_data = self.model_dump(exclude=exclude)
         copied_data = {k: v for k, v in copied_data.items() if v is not None}
 
-        copied_agent = Agent(**copied_data)
-        print("COPIED AGENT LLM", copied_agent.llm.model_name)
-        copied_agent.tools = deepcopy(self.tools)
+        copied_agent = Agent(**copied_data, llm=existing_llm, tools=self.tools)
 
         return copied_agent
 
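Note: the copy() hunk above replaces a debug print with a shallow copy of the agent's LLM whose callback list is cleared before it is handed to the new Agent, so handlers attached to the original (such as token counters) are not shared with the clone. A minimal runnable sketch of that pattern, using hypothetical stand-in classes rather than crewAI's real Agent and LLM types:

from copy import copy

class StubLLM:
    # Stand-in for a ChatOpenAI-style object carrying mutable callbacks.
    def __init__(self, model_name):
        self.model_name = model_name
        self.callbacks = []

class StubAgent:
    def __init__(self, llm):
        self.llm = llm

    def copy(self):
        # Shallow-copy the LLM so the clone gets its own object...
        existing_llm = copy(self.llm)
        # ...and reset its callbacks so the original's handlers do not
        # keep receiving (and double-counting) the clone's events.
        existing_llm.callbacks = []
        return StubAgent(llm=existing_llm)

original = StubAgent(StubLLM("gpt-4"))
original.llm.callbacks.append(object())  # a registered handler
clone = original.copy()
assert clone.llm is not original.llm
assert clone.llm.callbacks == [] and len(original.llm.callbacks) == 1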
@@ -303,16 +303,6 @@ class Crew(BaseModel):
             # TODO: I would expect we would want to merge the usage metrics from each crew execution
             results.append(output)
 
-            print("CREW USAGE METRICS:", crew.usage_metrics)
-            print(
-                "ORIGINAL AGENT USAGE METRICS",
-                [agent._token_process.get_summary() for agent in self.agents],
-            )
-            print(
-                "COPIED AGENT USAGE METRICS",
-                [agent._token_process.get_summary() for agent in crew.agents],
-            )
-
         return results
 
     async def kickoff_async(
@@ -321,13 +311,17 @@ class Crew(BaseModel):
         """Asynchronous kickoff method to start the crew execution."""
         return await asyncio.to_thread(self.kickoff, inputs)
 
+    # TODO: IF THERE ARE MULTIPLE INPUTS, THE USAGE METRICS FOR FIRST ONE COMES BACK AS 0.
     async def kickoff_for_each_async(self, inputs: List[Dict]) -> List[Any]:
-        async def run_crew(input_data):
-            crew = self.copy()
+        crew_copies = [self.copy() for _ in inputs]
 
+        async def run_crew(crew, input_data):
             return await crew.kickoff_async(inputs=input_data)
 
-        tasks = [asyncio.create_task(run_crew(input_data)) for input_data in inputs]
+        tasks = [
+            asyncio.create_task(run_crew(crew_copies[i], inputs[i]))
+            for i in range(len(inputs))
+        ]
 
         results = await asyncio.gather(*tasks)
 
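Note: the rewritten kickoff_for_each_async builds one crew copy per input before any work starts and fans the runs out with asyncio.create_task plus asyncio.gather, instead of copying inside the coroutine. A runnable sketch of the same fan-out shape with a toy stand-in (ToyCrew and its fields are illustrative, not crewAI's API):

import asyncio
from copy import deepcopy

class ToyCrew:
    def copy(self):
        # Each concurrent run gets an independent copy, so mutable
        # per-run state (e.g. token counters) is not shared.
        return deepcopy(self)

    async def kickoff_async(self, inputs):
        await asyncio.sleep(0)  # stand-in for real work
        return f"ran with {inputs}"

async def kickoff_for_each_async(crew, inputs):
    crew_copies = [crew.copy() for _ in inputs]

    async def run_crew(crew, input_data):
        return await crew.kickoff_async(inputs=input_data)

    tasks = [
        asyncio.create_task(run_crew(crew_copies[i], inputs[i]))
        for i in range(len(inputs))
    ]
    return await asyncio.gather(*tasks)

print(asyncio.run(kickoff_for_each_async(ToyCrew(), [{"a": 1}, {"a": 2}])))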
@@ -379,16 +373,15 @@ class Crew(BaseModel):
             if self.output_log_file:
                 self._file_handler.log(agent=role, task=task_output, status="completed")
 
-            # Update token usage for the current task
-            current_token_usage = task.agent._token_process.get_summary()
+            for agent in self.agents:
+                agent_token_usage = agent._token_process.get_summary()
             for key in total_token_usage:
-                total_token_usage[key] += current_token_usage.get(key, 0)
+                total_token_usage[key] += agent_token_usage.get(key, 0)
 
         self._finish_execution(task_output)
         # type: ignore # Item "None" of "Agent | None" has no attribute "_token_process")
 
         # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
-        # TODO: TEST AND FIX
         return self._format_output(task_output, total_token_usage)
 
     def _run_hierarchical_process(self) -> Union[str, Dict[str, Any]]:
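Note: the sequential-process hunk switches from reading only task.agent's token summary to summing the summary of every agent on the crew. The accumulation itself is plain dictionary arithmetic; a self-contained sketch (the usage keys are assumptions for illustration, not necessarily crewAI's exact field names):

# Per-agent summaries as _token_process.get_summary() might return them.
summaries = [
    {"total_tokens": 120, "prompt_tokens": 90, "completion_tokens": 30},
    {"total_tokens": 80, "prompt_tokens": 50, "completion_tokens": 30},
]
total_token_usage = {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
for agent_token_usage in summaries:
    for key in total_token_usage:
        # .get(key, 0) tolerates summaries that omit a counter.
        total_token_usage[key] += agent_token_usage.get(key, 0)
print(total_token_usage)  # {'total_tokens': 200, 'prompt_tokens': 140, 'completion_tokens': 60}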
@@ -432,10 +425,10 @@ class Crew(BaseModel):
                 agent=manager.role, task=task_output, status="completed"
             )
 
+        # TODO: GET TOKENS USAGE CALCULATED INCLUDING MANAGER
         self._finish_execution(task_output)
         # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
         manager_token_usage = manager._token_process.get_summary()
-        # TODO: TEST AND FIX
         return (
            self._format_output(task_output, manager_token_usage),
            manager_token_usage,
@@ -497,6 +490,8 @@ class Crew(BaseModel):
         Formats the output of the crew execution.
         If full_output is True, then returned data type will be a dictionary else returned outputs are string
         """
+
+        print("token_usage passed to _format_output", token_usage)
         if self.full_output:
             return { # type: ignore # Incompatible return value type (got "dict[str, Sequence[str | TaskOutput | None]]", expected "str")
                 "final_output": output,
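Note: per the docstring above, _format_output returns a dictionary when full_output is True and the plain output string otherwise. A hypothetical standalone sketch of that switch (only the final_output key is visible in the hunk; the usage_metrics key is an assumption):

def format_output(output, token_usage, full_output=False):
    # Dictionary carrying the final output plus extras when full_output is set...
    if full_output:
        return {"final_output": output, "usage_metrics": token_usage}
    # ...otherwise just the raw output.
    return output

print(format_output("done", {"total_tokens": 42}, full_output=True))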
@@ -35,18 +35,18 @@ class TokenProcess:
 
 class TokenCalcHandler(BaseCallbackHandler):
     id = uuid.uuid4() # TODO: REMOVE THIS
-    model: str = ""
+    model_name: str = ""
     token_cost_process: TokenProcess
 
-    def __init__(self, model, token_cost_process):
-        self.model = model
+    def __init__(self, model_name, token_cost_process):
+        self.model_name = model_name
         self.token_cost_process = token_cost_process
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         try:
-            encoding = tiktoken.encoding_for_model(self.model)
+            encoding = tiktoken.encoding_for_model(self.model_name)
         except KeyError:
             encoding = tiktoken.get_encoding("cl100k_base")
 
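Note: the handler now keys the tokenizer lookup off model_name and falls back to the cl100k_base encoding when tiktoken does not recognize the model. That lookup works standalone (requires the tiktoken package):

import tiktoken

def encoding_for(model_name: str):
    try:
        # Succeeds for model names tiktoken knows, e.g. "gpt-4".
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Unknown names raise KeyError; fall back to a sane default.
        return tiktoken.get_encoding("cl100k_base")

enc = encoding_for("some-unmapped-model")
print(len(enc.encode("hello world")))  # token count under the fallback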