mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-17 21:08:29 +00:00
Compare commits
2 Commits
tm-add-tas
...
bugfix/res
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cbc85f97bf | ||
|
|
fa397d47e3 |
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Any, List, Literal, Optional, Union, Dict
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||||
|
|
||||||
@@ -9,9 +9,10 @@ from crewai.agents import CacheHandler
|
|||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||||
from crewai.cli.constants import ENV_VARS
|
from crewai.cli.constants import ENV_VARS
|
||||||
from crewai.llm import LLM
|
|
||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
|
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||||
|
from crewai.llm import LLM
|
||||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
@@ -21,7 +22,6 @@ from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_F
|
|||||||
from crewai.utilities.converter import generate_model_description
|
from crewai.utilities.converter import generate_model_description
|
||||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
|
||||||
|
|
||||||
|
|
||||||
def mock_agent_ops_provider():
|
def mock_agent_ops_provider():
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
@@ -12,6 +13,7 @@ from crewai.agents.parser import (
|
|||||||
OutputParserException,
|
OutputParserException,
|
||||||
)
|
)
|
||||||
from crewai.agents.tools_handler import ToolsHandler
|
from crewai.agents.tools_handler import ToolsHandler
|
||||||
|
from crewai.tools.base_tool import BaseTool
|
||||||
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
|
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
|
||||||
from crewai.utilities import I18N, Printer
|
from crewai.utilities import I18N, Printer
|
||||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||||
@@ -22,6 +24,12 @@ from crewai.utilities.logger import Logger
|
|||||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolResult:
|
||||||
|
result: Any
|
||||||
|
result_as_answer: bool
|
||||||
|
|
||||||
|
|
||||||
class CrewAgentExecutor(CrewAgentExecutorMixin):
|
class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||||
_logger: Logger = Logger()
|
_logger: Logger = Logger()
|
||||||
|
|
||||||
@@ -33,7 +41,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
agent: BaseAgent,
|
agent: BaseAgent,
|
||||||
prompt: dict[str, str],
|
prompt: dict[str, str],
|
||||||
max_iter: int,
|
max_iter: int,
|
||||||
tools: List[Any],
|
tools: List[BaseTool],
|
||||||
tools_names: str,
|
tools_names: str,
|
||||||
stop_words: List[str],
|
stop_words: List[str],
|
||||||
tools_description: str,
|
tools_description: str,
|
||||||
@@ -70,7 +78,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
self.iterations = 0
|
self.iterations = 0
|
||||||
self.log_error_after = 3
|
self.log_error_after = 3
|
||||||
self.have_forced_answer = False
|
self.have_forced_answer = False
|
||||||
self.name_to_tool_map = {tool.name: tool for tool in self.tools}
|
self.tool_name_to_tool_map: Dict[str, BaseTool] = {
|
||||||
|
tool.name: tool for tool in self.tools
|
||||||
|
}
|
||||||
if self.llm.stop:
|
if self.llm.stop:
|
||||||
self.llm.stop = list(set(self.llm.stop + self.stop))
|
self.llm.stop = list(set(self.llm.stop + self.stop))
|
||||||
else:
|
else:
|
||||||
@@ -140,9 +150,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
formatted_answer = self._format_answer(answer)
|
formatted_answer = self._format_answer(answer)
|
||||||
|
|
||||||
if isinstance(formatted_answer, AgentAction):
|
if isinstance(formatted_answer, AgentAction):
|
||||||
action_result = self._use_tool(formatted_answer)
|
tool_result = self._execute_tool_and_check_finality(
|
||||||
formatted_answer.text += f"\nObservation: {action_result}"
|
formatted_answer
|
||||||
formatted_answer.result = action_result
|
)
|
||||||
|
formatted_answer.text += f"\nObservation: {tool_result.result}"
|
||||||
|
formatted_answer.result = tool_result.result
|
||||||
|
if tool_result.result_as_answer:
|
||||||
|
return AgentFinish(
|
||||||
|
thought="",
|
||||||
|
output=tool_result.result,
|
||||||
|
text=formatted_answer.text,
|
||||||
|
)
|
||||||
self._show_logs(formatted_answer)
|
self._show_logs(formatted_answer)
|
||||||
|
|
||||||
if self.step_callback:
|
if self.step_callback:
|
||||||
@@ -239,7 +257,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
|
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _use_tool(self, agent_action: AgentAction) -> Any:
|
def _execute_tool_and_check_finality(self, agent_action: AgentAction) -> ToolResult:
|
||||||
tool_usage = ToolUsage(
|
tool_usage = ToolUsage(
|
||||||
tools_handler=self.tools_handler,
|
tools_handler=self.tools_handler,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
@@ -255,19 +273,25 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
|
|
||||||
if isinstance(tool_calling, ToolUsageErrorException):
|
if isinstance(tool_calling, ToolUsageErrorException):
|
||||||
tool_result = tool_calling.message
|
tool_result = tool_calling.message
|
||||||
|
return ToolResult(result=tool_result, result_as_answer=False)
|
||||||
else:
|
else:
|
||||||
if tool_calling.tool_name.casefold().strip() in [
|
if tool_calling.tool_name.casefold().strip() in [
|
||||||
name.casefold().strip() for name in self.name_to_tool_map
|
name.casefold().strip() for name in self.tool_name_to_tool_map
|
||||||
] or tool_calling.tool_name.casefold().replace("_", " ") in [
|
] or tool_calling.tool_name.casefold().replace("_", " ") in [
|
||||||
name.casefold().strip() for name in self.name_to_tool_map
|
name.casefold().strip() for name in self.tool_name_to_tool_map
|
||||||
]:
|
]:
|
||||||
tool_result = tool_usage.use(tool_calling, agent_action.text)
|
tool_result = tool_usage.use(tool_calling, agent_action.text)
|
||||||
|
tool = self.tool_name_to_tool_map.get(tool_calling.tool_name)
|
||||||
|
if tool:
|
||||||
|
return ToolResult(
|
||||||
|
result=tool_result, result_as_answer=tool.result_as_answer
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
tool_result = self._i18n.errors("wrong_tool_name").format(
|
tool_result = self._i18n.errors("wrong_tool_name").format(
|
||||||
tool=tool_calling.tool_name,
|
tool=tool_calling.tool_name,
|
||||||
tools=", ".join([tool.name.casefold() for tool in self.tools]),
|
tools=", ".join([tool.name.casefold() for tool in self.tools]),
|
||||||
)
|
)
|
||||||
return tool_result
|
return ToolResult(result=tool_result, result_as_answer=False)
|
||||||
|
|
||||||
def _summarize_messages(self) -> None:
|
def _summarize_messages(self) -> None:
|
||||||
messages_groups = []
|
messages_groups = []
|
||||||
@@ -333,9 +357,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
|
if self.crew is not None and hasattr(self.crew, "_train_iteration"):
|
||||||
train_iteration = self.crew._train_iteration
|
train_iteration = self.crew._train_iteration
|
||||||
if agent_id in training_data and isinstance(train_iteration, int):
|
if agent_id in training_data and isinstance(train_iteration, int):
|
||||||
training_data[agent_id][train_iteration]["improved_output"] = (
|
training_data[agent_id][train_iteration][
|
||||||
result.output
|
"improved_output"
|
||||||
)
|
] = result.output
|
||||||
training_handler.save(training_data)
|
training_handler.save(training_data)
|
||||||
else:
|
else:
|
||||||
self._logger.log(
|
self._logger.log(
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ class BaseTool(BaseModel, ABC):
|
|||||||
description=self.description,
|
description=self.description,
|
||||||
args_schema=self.args_schema,
|
args_schema=self.args_schema,
|
||||||
func=self._run,
|
func=self._run,
|
||||||
|
result_as_answer=self.result_as_answer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ class CrewStructuredTool:
|
|||||||
description: str,
|
description: str,
|
||||||
args_schema: type[BaseModel],
|
args_schema: type[BaseModel],
|
||||||
func: Callable[..., Any],
|
func: Callable[..., Any],
|
||||||
|
result_as_answer: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the structured tool.
|
"""Initialize the structured tool.
|
||||||
|
|
||||||
@@ -30,12 +31,14 @@ class CrewStructuredTool:
|
|||||||
description: A description of what the tool does
|
description: A description of what the tool does
|
||||||
args_schema: The pydantic model for the tool's arguments
|
args_schema: The pydantic model for the tool's arguments
|
||||||
func: The function to run when the tool is called
|
func: The function to run when the tool is called
|
||||||
|
result_as_answer: Whether to return the output directly
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
self.description = description
|
self.description = description
|
||||||
self.args_schema = args_schema
|
self.args_schema = args_schema
|
||||||
self.func = func
|
self.func = func
|
||||||
self._logger = Logger()
|
self._logger = Logger()
|
||||||
|
self.result_as_answer = result_as_answer
|
||||||
|
|
||||||
# Validate the function signature matches the schema
|
# Validate the function signature matches the schema
|
||||||
self._validate_function_signature()
|
self._validate_function_signature()
|
||||||
@@ -98,6 +101,7 @@ class CrewStructuredTool:
|
|||||||
description=description,
|
description=description,
|
||||||
args_schema=schema,
|
args_schema=schema,
|
||||||
func=func,
|
func=func,
|
||||||
|
result_as_answer=return_direct,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user