Adding new LLM class

This commit is contained in:
João Moura
2024-09-23 03:59:05 -03:00
parent 59e51f18fd
commit a19a4a5556
9 changed files with 124 additions and 93 deletions

View File

@@ -13,7 +13,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
)
from crewai.utilities.logger import Logger
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.llm import LLM
from crewai.agents.parser import (
AgentAction,
AgentFinish,
@@ -104,23 +103,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
try:
while not isinstance(formatted_answer, AgentFinish):
if not self.request_within_rpm_limit or self.request_within_rpm_limit():
if isinstance(self.llm, str):
llm = LLM(
model=self.llm,
stop=self.stop if self.use_stop_words else None,
callbacks=self.callbacks,
)
elif isinstance(self.llm, LLM):
llm = self.llm
else:
llm = LLM(
model=self.llm.model,
provider=getattr(self.llm, "provider", "litellm"),
stop=self.stop if self.use_stop_words else None,
callbacks=self.callbacks,
**getattr(self.llm, "llm_kwargs", {}),
)
answer = llm.call(self.messages)
answer = self.llm.call(
self.messages,
callbacks=self.callbacks,
)
if not self.use_stop_words:
try:
@@ -139,6 +125,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
action_result = self._use_tool(formatted_answer)
formatted_answer.text += f"\nObservation: {action_result}"
formatted_answer.result = action_result
print("formatted_answer", formatted_answer)
self._show_logs(formatted_answer)
if self.step_callback:
@@ -194,7 +181,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if isinstance(formatted_answer, AgentAction):
thought = re.sub(r"\n+", "\n", formatted_answer.thought)
formatted_json = json.dumps(
json.loads(formatted_answer.tool_input),
formatted_answer.tool_input,
indent=2,
ensure_ascii=False,
)
@@ -253,16 +240,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return tool_result
def _summarize_messages(self) -> None:
if isinstance(self.llm, str):
llm = LLM(model=self.llm)
elif isinstance(self.llm, LLM):
llm = self.llm
else:
llm = LLM(
model=self.llm.model,
provider=getattr(self.llm, "provider", "litellm"),
**getattr(self.llm, "llm_kwargs", {}),
)
messages_groups = []
for message in self.messages:
@@ -272,7 +249,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
summarized_contents = []
for group in messages_groups:
summary = llm.call(
summary = self.llm.call(
[
self._format_msg(
self._i18n.slices("summarizer_system_message"), role="system"
@@ -280,7 +257,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._format_msg(
self._i18n.errors("sumamrize_instruction").format(group=group),
),
]
],
callbacks=self.callbacks,
)
summarized_contents.append(summary)