Compare commits

...

2 Commits

Author SHA1 Message Date
Brandon Hancock
cbc85f97bf remove print statements 2024-12-02 13:15:35 -05:00
Brandon Hancock
fa397d47e3 v1 of fix implemented. Need to confirm with tokens. 2024-12-02 12:22:50 -05:00
4 changed files with 44 additions and 15 deletions

View File

@@ -1,7 +1,7 @@
import os
import shutil
import subprocess
from typing import Any, List, Literal, Optional, Union, Dict
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -9,9 +9,10 @@ from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.cli.constants import ENV_VARS
from crewai.llm import LLM
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.task import Task
from crewai.tools import BaseTool
@@ -21,7 +22,6 @@ from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_F
from crewai.utilities.converter import generate_model_description
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
def mock_agent_ops_provider():

View File

@@ -1,5 +1,6 @@
import json
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -12,6 +13,7 @@ from crewai.agents.parser import (
OutputParserException,
)
from crewai.agents.tools_handler import ToolsHandler
from crewai.tools.base_tool import BaseTool
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
from crewai.utilities import I18N, Printer
from crewai.utilities.constants import TRAINING_DATA_FILE
@@ -22,6 +24,12 @@ from crewai.utilities.logger import Logger
from crewai.utilities.training_handler import CrewTrainingHandler
@dataclass
class ToolResult:
result: Any
result_as_answer: bool
class CrewAgentExecutor(CrewAgentExecutorMixin):
_logger: Logger = Logger()
@@ -33,7 +41,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
agent: BaseAgent,
prompt: dict[str, str],
max_iter: int,
tools: List[Any],
tools: List[BaseTool],
tools_names: str,
stop_words: List[str],
tools_description: str,
@@ -70,7 +78,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.iterations = 0
self.log_error_after = 3
self.have_forced_answer = False
self.name_to_tool_map = {tool.name: tool for tool in self.tools}
self.tool_name_to_tool_map: Dict[str, BaseTool] = {
tool.name: tool for tool in self.tools
}
if self.llm.stop:
self.llm.stop = list(set(self.llm.stop + self.stop))
else:
@@ -140,9 +150,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer = self._format_answer(answer)
if isinstance(formatted_answer, AgentAction):
action_result = self._use_tool(formatted_answer)
formatted_answer.text += f"\nObservation: {action_result}"
formatted_answer.result = action_result
tool_result = self._execute_tool_and_check_finality(
formatted_answer
)
formatted_answer.text += f"\nObservation: {tool_result.result}"
formatted_answer.result = tool_result.result
if tool_result.result_as_answer:
return AgentFinish(
thought="",
output=tool_result.result,
text=formatted_answer.text,
)
self._show_logs(formatted_answer)
if self.step_callback:
@@ -239,7 +257,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
)
def _use_tool(self, agent_action: AgentAction) -> Any:
def _execute_tool_and_check_finality(self, agent_action: AgentAction) -> ToolResult:
tool_usage = ToolUsage(
tools_handler=self.tools_handler,
tools=self.tools,
@@ -255,19 +273,25 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if isinstance(tool_calling, ToolUsageErrorException):
tool_result = tool_calling.message
return ToolResult(result=tool_result, result_as_answer=False)
else:
if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in self.name_to_tool_map
name.casefold().strip() for name in self.tool_name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in self.name_to_tool_map
name.casefold().strip() for name in self.tool_name_to_tool_map
]:
tool_result = tool_usage.use(tool_calling, agent_action.text)
tool = self.tool_name_to_tool_map.get(tool_calling.tool_name)
if tool:
return ToolResult(
result=tool_result, result_as_answer=tool.result_as_answer
)
else:
tool_result = self._i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in self.tools]),
)
return tool_result
return ToolResult(result=tool_result, result_as_answer=False)
def _summarize_messages(self) -> None:
messages_groups = []
@@ -333,9 +357,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
train_iteration = self.crew._train_iteration
if agent_id in training_data and isinstance(train_iteration, int):
training_data[agent_id][train_iteration]["improved_output"] = (
result.output
)
training_data[agent_id][train_iteration][
"improved_output"
] = result.output
training_handler.save(training_data)
else:
self._logger.log(

View File

@@ -73,6 +73,7 @@ class BaseTool(BaseModel, ABC):
description=self.description,
args_schema=self.args_schema,
func=self._run,
result_as_answer=self.result_as_answer,
)
@classmethod

View File

@@ -22,6 +22,7 @@ class CrewStructuredTool:
description: str,
args_schema: type[BaseModel],
func: Callable[..., Any],
result_as_answer: bool = False,
) -> None:
"""Initialize the structured tool.
@@ -30,12 +31,14 @@ class CrewStructuredTool:
description: A description of what the tool does
args_schema: The pydantic model for the tool's arguments
func: The function to run when the tool is called
result_as_answer: Whether to return the output directly
"""
self.name = name
self.description = description
self.args_schema = args_schema
self.func = func
self._logger = Logger()
self.result_as_answer = result_as_answer
# Validate the function signature matches the schema
self._validate_function_signature()
@@ -98,6 +101,7 @@ class CrewStructuredTool:
description=description,
args_schema=schema,
func=func,
result_as_answer=return_direct,
)
@staticmethod