Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-24 15:48:23 +00:00
Clean up code for review
@@ -1,5 +1,5 @@
 import os
-from copy import copy
+from copy import copy as shallow_copy
 from typing import Any, List, Optional, Tuple

 from langchain.agents.agent import RunnableAgent
@@ -76,10 +76,6 @@ class Agent(BaseAgent):
     response_template: Optional[str] = Field(
         default=None, description="Response format for the agent."
     )

-    allow_code_execution: Optional[bool] = Field(
-        default=False, description="Enable code execution for the agent."
-    )
     allow_code_execution: Optional[bool] = Field(
         default=False, description="Enable code execution for the agent."
     )
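The hunk above drops one of two identical `allow_code_execution` declarations. In a Python class body the later binding simply shadows the earlier one, so the duplicate field was dead code; a minimal sketch of that behavior, using an illustrative model name rather than anything from the repo:

    from typing import Optional
    from pydantic import BaseModel, Field

    class Demo(BaseModel):
        # The first declaration is shadowed by the second one below;
        # only the last binding of a class attribute takes effect.
        flag: Optional[bool] = Field(default=True, description="shadowed")
        flag: Optional[bool] = Field(default=False, description="effective")

    print(Demo().flag)  # False: the earlier duplicate contributed nothing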
@@ -286,8 +282,8 @@ class Agent(BaseAgent):
             "llm",
         }

-        # TODO: TEST REMOVING THIS AND SEE IF ANYTHING CHANGES
-        existing_llm = copy(self.llm)
+        # Copy llm and clear callbacks
+        existing_llm = shallow_copy(self.llm)
         existing_llm.callbacks = []
         copied_data = self.model_dump(exclude=exclude)
         copied_data = {k: v for k, v in copied_data.items() if v is not None}
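On the `shallow_copy(self.llm)` line above: a shallow copy is sufficient because the only mutation is rebinding `callbacks` on the copy, which leaves the original LLM object untouched. A minimal sketch of that idea, using a made-up `FakeLLM` class rather than the real LangChain one:

    from copy import copy as shallow_copy

    class FakeLLM:
        def __init__(self):
            self.callbacks = ["tracing-handler"]

    original = FakeLLM()
    clone = shallow_copy(original)
    clone.callbacks = []        # rebinds the attribute on the clone only

    print(original.callbacks)   # ['tracing-handler'] -- the original keeps its handlers
    print(clone.callbacks)      # []

Note that mutating the shared list in place (e.g. `clone.callbacks.clear()`) would have affected both objects; the rebinding done here avoids that.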
@@ -5,15 +5,15 @@ from typing import Any, Dict, List, Optional, Union

 from langchain_core.callbacks import BaseCallbackHandler
 from pydantic import (
     UUID4,
     BaseModel,
     ConfigDict,
     Field,
     InstanceOf,
     Json,
     PrivateAttr,
     field_validator,
     model_validator,
 )
 from pydantic_core import PydanticCustomError

@@ -69,7 +69,7 @@ class Crew(BaseModel):
     _train: Optional[bool] = PrivateAttr(default=False)
     _train_iteration: Optional[int] = PrivateAttr()

-    cache: bool = Field(default=False)
+    cache: bool = Field(default=True)
     model_config = ConfigDict(arbitrary_types_allowed=True)
     tasks: List[Task] = Field(default_factory=list)
     agents: List[BaseAgent] = Field(default_factory=list)
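The `cache` default flips from `False` to `True`, so callers who relied on uncached runs now have to opt out explicitly. A small usage sketch, assuming an agent and a task are already defined (the `researcher` and `research_task` names are illustrative):

    # Caching is now on by default:
    crew = Crew(agents=[researcher], tasks=[research_task])
    assert crew.cache is True

    # Passing cache=False restores the previous default behavior:
    uncached_crew = Crew(agents=[researcher], tasks=[research_task], cache=False)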
@@ -294,16 +294,13 @@ class Crew(BaseModel):
             # type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
             agent.crew = self  # type: ignore[attr-defined]
             # TODO: Create an AgentFunctionCalling protocol for future refactoring
-            if (
-                hasattr(agent, "function_calling_llm")
-                and not agent.function_calling_llm
-            ):
+            if not agent.function_calling_llm:
                 agent.function_calling_llm = self.function_calling_llm

-            if hasattr(agent, "allow_code_execution") and agent.allow_code_execution:
+            if agent.allow_code_execution:
                 agent.tools += agent.get_code_execution_tools()

-            if hasattr(agent, "step_callback") and not agent.step_callback:
+            if not agent.step_callback:
                 agent.step_callback = self.step_callback

             agent.create_agent_executor()
@@ -321,9 +318,7 @@ class Crew(BaseModel):
             raise NotImplementedError(
                 f"The process '{self.process}' is not implemented yet."
             )
-        metrics = metrics + [
-            agent._token_process.get_summary() for agent in self.agents
-        ]
+        metrics += [agent._token_process.get_summary() for agent in self.agents]

         self.usage_metrics = {
             key: sum([m[key] for m in metrics if m is not None]) for key in metrics[0]
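The `self.usage_metrics` comprehension above takes the keys of the first summary and sums each key across every non-None summary in `metrics`. A worked sketch with made-up numbers (the exact summary keys beyond `total_tokens` and `prompt_tokens` are an assumption):

    metrics = [
        {"total_tokens": 120, "prompt_tokens": 80, "completion_tokens": 40},
        {"total_tokens": 60, "prompt_tokens": 45, "completion_tokens": 15},
        None,  # a summary that is None is filtered out of the sum
    ]

    usage_metrics = {
        key: sum([m[key] for m in metrics if m is not None]) for key in metrics[0]
    }
    print(usage_metrics)
    # {'total_tokens': 180, 'prompt_tokens': 125, 'completion_tokens': 55}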
@@ -404,7 +399,6 @@ class Crew(BaseModel):
                 ]
                 if len(self.agents) > 1 and len(agents_for_delegation) > 0:
                     task.tools += task.agent.get_delegation_tools(agents_for_delegation)
-                    task.tools += task.agent.get_delegation_tools(agents_for_delegation)

             role = task.agent.role if task.agent is not None else "None"
             self._logger.log("debug", f"== Working Agent: {role}", color="bold_purple")
@@ -479,12 +473,9 @@ class Crew(BaseModel):
         self._finish_execution(task_output)

         # type: ignore # Incompatible return value type (got "tuple[str, Any]", expected "str")
-        token_usage = self._calculate_usage_metrics()
+        token_usage = self.calculate_usage_metrics()

-        return (
-            self._format_output(task_output, token_usage),
-            token_usage,
-        )
+        return self._format_output(task_output, token_usage), token_usage

     def copy(self):
         """Create a deep copy of the Crew."""
@@ -558,9 +549,8 @@ class Crew(BaseModel):
         self._rpm_controller.stop_rpm_counter()
         self._telemetry.end_crew(self, output)

-    def _calculate_usage_metrics(
-        self,
-    ) -> Dict[str, int]:
+    def calculate_usage_metrics(self) -> Dict[str, int]:
+        """Calculates and returns the usage metrics."""
         total_usage_metrics = {
             "total_tokens": 0,
             "prompt_tokens": 0,
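With `_calculate_usage_metrics` renamed to `calculate_usage_metrics`, the method becomes part of the crew's public surface. A hedged usage sketch, assuming a crew instance is already configured:

    result = crew.kickoff()
    usage = crew.calculate_usage_metrics()
    print(usage["total_tokens"], usage["prompt_tokens"])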
@@ -571,17 +561,13 @@ class Crew(BaseModel):
         for agent in self.agents:
             if hasattr(agent, "_token_process"):
                 token_sum = agent._token_process.get_summary()
-                total_usage_metrics = {
-                    key: total_usage_metrics[key] + token_sum[key]
-                    for key in total_usage_metrics
-                }
+                for key in total_usage_metrics:
+                    total_usage_metrics[key] += token_sum.get(key, 0)

-        if self.manager_agent:
+        if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
             token_sum = self.manager_agent._token_process.get_summary()
-            total_usage_metrics = {
-                key: total_usage_metrics[key] + token_sum[key]
-                for key in total_usage_metrics
-            }
+            for key in total_usage_metrics:
+                total_usage_metrics[key] += token_sum.get(key, 0)

         return total_usage_metrics

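The rewritten accumulation loop also changes behavior slightly: `token_sum.get(key, 0)` treats a missing key as zero, where the old dict comprehension with `token_sum[key]` would have raised a KeyError. A worked sketch with illustrative keys and numbers:

    total_usage_metrics = {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}

    # A summary that lacks "completion_tokens" no longer breaks the aggregation:
    token_sum = {"total_tokens": 100, "prompt_tokens": 100}

    for key in total_usage_metrics:
        total_usage_metrics[key] += token_sum.get(key, 0)

    print(total_usage_metrics)
    # {'total_tokens': 100, 'prompt_tokens': 100, 'completion_tokens': 0}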