Compare commits

...

2 Commits

Author SHA1 Message Date
Brandon Hancock
cbc85f97bf remove print statements 2024-12-02 13:15:35 -05:00
Brandon Hancock
fa397d47e3 v1 of fix implemented. Need to confirm with tokens. 2024-12-02 12:22:50 -05:00
4 changed files with 44 additions and 15 deletions

View File

@@ -1,7 +1,7 @@
import os import os
import shutil import shutil
import subprocess import subprocess
from typing import Any, List, Literal, Optional, Union, Dict from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -9,9 +9,10 @@ from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.cli.constants import ENV_VARS from crewai.cli.constants import ENV_VARS
from crewai.llm import LLM
from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.task import Task from crewai.task import Task
from crewai.tools import BaseTool from crewai.tools import BaseTool
@@ -21,7 +22,6 @@ from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_F
from crewai.utilities.converter import generate_model_description from crewai.utilities.converter import generate_model_description
from crewai.utilities.token_counter_callback import TokenCalcHandler from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
def mock_agent_ops_provider(): def mock_agent_ops_provider():

View File

@@ -1,5 +1,6 @@
import json import json
import re import re
from dataclasses import dataclass
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -12,6 +13,7 @@ from crewai.agents.parser import (
OutputParserException, OutputParserException,
) )
from crewai.agents.tools_handler import ToolsHandler from crewai.agents.tools_handler import ToolsHandler
from crewai.tools.base_tool import BaseTool
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
from crewai.utilities import I18N, Printer from crewai.utilities import I18N, Printer
from crewai.utilities.constants import TRAINING_DATA_FILE from crewai.utilities.constants import TRAINING_DATA_FILE
@@ -22,6 +24,12 @@ from crewai.utilities.logger import Logger
from crewai.utilities.training_handler import CrewTrainingHandler from crewai.utilities.training_handler import CrewTrainingHandler
@dataclass
class ToolResult:
result: Any
result_as_answer: bool
class CrewAgentExecutor(CrewAgentExecutorMixin): class CrewAgentExecutor(CrewAgentExecutorMixin):
_logger: Logger = Logger() _logger: Logger = Logger()
@@ -33,7 +41,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
agent: BaseAgent, agent: BaseAgent,
prompt: dict[str, str], prompt: dict[str, str],
max_iter: int, max_iter: int,
tools: List[Any], tools: List[BaseTool],
tools_names: str, tools_names: str,
stop_words: List[str], stop_words: List[str],
tools_description: str, tools_description: str,
@@ -70,7 +78,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.iterations = 0 self.iterations = 0
self.log_error_after = 3 self.log_error_after = 3
self.have_forced_answer = False self.have_forced_answer = False
self.name_to_tool_map = {tool.name: tool for tool in self.tools} self.tool_name_to_tool_map: Dict[str, BaseTool] = {
tool.name: tool for tool in self.tools
}
if self.llm.stop: if self.llm.stop:
self.llm.stop = list(set(self.llm.stop + self.stop)) self.llm.stop = list(set(self.llm.stop + self.stop))
else: else:
@@ -140,9 +150,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer = self._format_answer(answer) formatted_answer = self._format_answer(answer)
if isinstance(formatted_answer, AgentAction): if isinstance(formatted_answer, AgentAction):
action_result = self._use_tool(formatted_answer) tool_result = self._execute_tool_and_check_finality(
formatted_answer.text += f"\nObservation: {action_result}" formatted_answer
formatted_answer.result = action_result )
formatted_answer.text += f"\nObservation: {tool_result.result}"
formatted_answer.result = tool_result.result
if tool_result.result_as_answer:
return AgentFinish(
thought="",
output=tool_result.result,
text=formatted_answer.text,
)
self._show_logs(formatted_answer) self._show_logs(formatted_answer)
if self.step_callback: if self.step_callback:
@@ -239,7 +257,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n" content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
) )
def _use_tool(self, agent_action: AgentAction) -> Any: def _execute_tool_and_check_finality(self, agent_action: AgentAction) -> ToolResult:
tool_usage = ToolUsage( tool_usage = ToolUsage(
tools_handler=self.tools_handler, tools_handler=self.tools_handler,
tools=self.tools, tools=self.tools,
@@ -255,19 +273,25 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if isinstance(tool_calling, ToolUsageErrorException): if isinstance(tool_calling, ToolUsageErrorException):
tool_result = tool_calling.message tool_result = tool_calling.message
return ToolResult(result=tool_result, result_as_answer=False)
else: else:
if tool_calling.tool_name.casefold().strip() in [ if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in self.name_to_tool_map name.casefold().strip() for name in self.tool_name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [ ] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in self.name_to_tool_map name.casefold().strip() for name in self.tool_name_to_tool_map
]: ]:
tool_result = tool_usage.use(tool_calling, agent_action.text) tool_result = tool_usage.use(tool_calling, agent_action.text)
tool = self.tool_name_to_tool_map.get(tool_calling.tool_name)
if tool:
return ToolResult(
result=tool_result, result_as_answer=tool.result_as_answer
)
else: else:
tool_result = self._i18n.errors("wrong_tool_name").format( tool_result = self._i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name, tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in self.tools]), tools=", ".join([tool.name.casefold() for tool in self.tools]),
) )
return tool_result return ToolResult(result=tool_result, result_as_answer=False)
def _summarize_messages(self) -> None: def _summarize_messages(self) -> None:
messages_groups = [] messages_groups = []
@@ -333,9 +357,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.crew is not None and hasattr(self.crew, "_train_iteration"): if self.crew is not None and hasattr(self.crew, "_train_iteration"):
train_iteration = self.crew._train_iteration train_iteration = self.crew._train_iteration
if agent_id in training_data and isinstance(train_iteration, int): if agent_id in training_data and isinstance(train_iteration, int):
training_data[agent_id][train_iteration]["improved_output"] = ( training_data[agent_id][train_iteration][
result.output "improved_output"
) ] = result.output
training_handler.save(training_data) training_handler.save(training_data)
else: else:
self._logger.log( self._logger.log(

View File

@@ -73,6 +73,7 @@ class BaseTool(BaseModel, ABC):
description=self.description, description=self.description,
args_schema=self.args_schema, args_schema=self.args_schema,
func=self._run, func=self._run,
result_as_answer=self.result_as_answer,
) )
@classmethod @classmethod

View File

@@ -22,6 +22,7 @@ class CrewStructuredTool:
description: str, description: str,
args_schema: type[BaseModel], args_schema: type[BaseModel],
func: Callable[..., Any], func: Callable[..., Any],
result_as_answer: bool = False,
) -> None: ) -> None:
"""Initialize the structured tool. """Initialize the structured tool.
@@ -30,12 +31,14 @@ class CrewStructuredTool:
description: A description of what the tool does description: A description of what the tool does
args_schema: The pydantic model for the tool's arguments args_schema: The pydantic model for the tool's arguments
func: The function to run when the tool is called func: The function to run when the tool is called
result_as_answer: Whether to return the output directly
""" """
self.name = name self.name = name
self.description = description self.description = description
self.args_schema = args_schema self.args_schema = args_schema
self.func = func self.func = func
self._logger = Logger() self._logger = Logger()
self.result_as_answer = result_as_answer
# Validate the function signature matches the schema # Validate the function signature matches the schema
self._validate_function_signature() self._validate_function_signature()
@@ -98,6 +101,7 @@ class CrewStructuredTool:
description=description, description=description,
args_schema=schema, args_schema=schema,
func=func, func=func,
result_as_answer=return_direct,
) )
@staticmethod @staticmethod