Clean up code for review

This commit is contained in:
Brandon Hancock
2024-07-01 14:56:42 -04:00
parent 6a47eb4f9e
commit 60c8f86345
2 changed files with 26 additions and 44 deletions

View File

@@ -1,5 +1,5 @@
import os
from copy import copy
from copy import copy as shallow_copy
from typing import Any, List, Optional, Tuple
from langchain.agents.agent import RunnableAgent
@@ -76,10 +76,6 @@ class Agent(BaseAgent):
response_template: Optional[str] = Field(
default=None, description="Response format for the agent."
)
allow_code_execution: Optional[bool] = Field(
default=False, description="Enable code execution for the agent."
)
allow_code_execution: Optional[bool] = Field(
default=False, description="Enable code execution for the agent."
)
@@ -286,8 +282,8 @@ class Agent(BaseAgent):
"llm",
}
# TODO: TEST REMOVING THIS AND SEE IF ANYTHING CHANGES
existing_llm = copy(self.llm)
# Copy llm and clear callbacks
existing_llm = shallow_copy(self.llm)
existing_llm.callbacks = []
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}

View File

@@ -5,15 +5,15 @@ from typing import Any, Dict, List, Optional, Union
from langchain_core.callbacks import BaseCallbackHandler
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
Json,
PrivateAttr,
field_validator,
model_validator,
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
Json,
PrivateAttr,
field_validator,
model_validator,
)
from pydantic_core import PydanticCustomError
@@ -69,7 +69,7 @@ class Crew(BaseModel):
_train: Optional[bool] = PrivateAttr(default=False)
_train_iteration: Optional[int] = PrivateAttr()
cache: bool = Field(default=False)
cache: bool = Field(default=True)
model_config = ConfigDict(arbitrary_types_allowed=True)
tasks: List[Task] = Field(default_factory=list)
agents: List[BaseAgent] = Field(default_factory=list)
@@ -294,16 +294,13 @@ class Crew(BaseModel):
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
agent.crew = self # type: ignore[attr-defined]
# TODO: Create an AgentFunctionCalling protocol for future refactoring
if (
hasattr(agent, "function_calling_llm")
and not agent.function_calling_llm
):
if not agent.function_calling_llm:
agent.function_calling_llm = self.function_calling_llm
if hasattr(agent, "allow_code_execution") and agent.allow_code_execution:
if agent.allow_code_execution:
agent.tools += agent.get_code_execution_tools()
if hasattr(agent, "step_callback") and not agent.step_callback:
if not agent.step_callback:
agent.step_callback = self.step_callback
agent.create_agent_executor()
@@ -321,9 +318,7 @@ class Crew(BaseModel):
raise NotImplementedError(
f"The process '{self.process}' is not implemented yet."
)
metrics = metrics + [
agent._token_process.get_summary() for agent in self.agents
]
metrics += [agent._token_process.get_summary() for agent in self.agents]
self.usage_metrics = {
key: sum([m[key] for m in metrics if m is not None]) for key in metrics[0]
@@ -404,7 +399,6 @@ class Crew(BaseModel):
]
if len(self.agents) > 1 and len(agents_for_delegation) > 0:
task.tools += task.agent.get_delegation_tools(agents_for_delegation)
task.tools += task.agent.get_delegation_tools(agents_for_delegation)
role = task.agent.role if task.agent is not None else "None"
self._logger.log("debug", f"== Working Agent: {role}", color="bold_purple")
@@ -479,12 +473,9 @@ class Crew(BaseModel):
self._finish_execution(task_output)
# type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
token_usage = self._calculate_usage_metrics()
token_usage = self.calculate_usage_metrics()
return (
self._format_output(task_output, token_usage),
token_usage,
)
return self._format_output(task_output, token_usage), token_usage
def copy(self):
"""Create a deep copy of the Crew."""
@@ -558,9 +549,8 @@ class Crew(BaseModel):
self._rpm_controller.stop_rpm_counter()
self._telemetry.end_crew(self, output)
def _calculate_usage_metrics(
self,
) -> Dict[str, int]:
def calculate_usage_metrics(self) -> Dict[str, int]:
"""Calculates and returns the usage metrics."""
total_usage_metrics = {
"total_tokens": 0,
"prompt_tokens": 0,
@@ -571,17 +561,13 @@ class Crew(BaseModel):
for agent in self.agents:
if hasattr(agent, "_token_process"):
token_sum = agent._token_process.get_summary()
total_usage_metrics = {
key: total_usage_metrics[key] + token_sum[key]
for key in total_usage_metrics
}
for key in total_usage_metrics:
total_usage_metrics[key] += token_sum.get(key, 0)
if self.manager_agent:
if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
token_sum = self.manager_agent._token_process.get_summary()
total_usage_metrics = {
key: total_usage_metrics[key] + token_sum[key]
for key in total_usage_metrics
}
for key in total_usage_metrics:
total_usage_metrics[key] += token_sum.get(key, 0)
return total_usage_metrics