mirror of https://github.com/crewAIInc/crewAI.git

feat: pass response model to lite agent

@@ -5,10 +5,18 @@ the ReAct (Reasoning and Acting) format, converting them into structured
 AgentAction or AgentFinish objects.
 """
 
+from __future__ import annotations
+
 from dataclasses import dataclass
+import re
+from typing import TYPE_CHECKING
 
 from json_repair import repair_json  # type: ignore[import-untyped]
 
+
+if TYPE_CHECKING:
+    from pydantic import BaseModel
+
 from crewai.agents.constants import (
     ACTION_INPUT_ONLY_REGEX,
     ACTION_INPUT_REGEX,

@@ -42,6 +50,7 @@ class AgentFinish:
     thought: str
     output: str
     text: str
+    pydantic: BaseModel | None = None  # Optional structured output from response_model
 
 
 class OutputParserError(Exception):

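Note: `AgentFinish` is a dataclass, so giving the new field a `None` default keeps existing constructor calls working. A minimal sketch of a structured finish, assuming `AgentFinish` lives at `crewai.agents.parser`:

from pydantic import BaseModel

from crewai.agents.parser import AgentFinish  # assumed import path


class Report(BaseModel):
    title: str
    score: float


result = Report(title="Q3 summary", score=0.92)
finish = AgentFinish(
    thought="",
    output=result.model_dump_json(),  # string view, unchanged behavior
    text=result.model_dump_json(),
    pydantic=result,  # new: the typed object travels with the finish
)
print(finish.pydantic.score)  # 0.92
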
@@ -140,7 +149,7 @@ def _extract_thought(text: str) -> str:
         text: The full agent output text.
 
     Returns:
-        The extracted thought string.
+        The extracted thought string with duplicate consecutive "Thought:" prefixes removed.
     """
     thought_index = text.find("\nAction")
     if thought_index == -1:

@@ -149,7 +158,13 @@ def _extract_thought(text: str) -> str:
         return ""
     thought = text[:thought_index].strip()
     # Remove any triple backticks from the thought string
-    return thought.replace("```", "").strip()
+    thought = thought.replace("```", "").strip()
+
+    thought = re.sub(r"(?i)^thought:\s*", "", thought, count=1)
+
+    thought = re.sub(r"(?i)\nthought:\s*", "\n", thought)
+
+    return thought.strip()
 
 
 def _clean_action(text: str) -> str:

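The two `re.sub` calls are the substance of this hunk: drop one leading "Thought:" label, then collapse repeated labels at the start of later lines. A self-contained stdlib illustration mirroring the regexes above:

import re


def strip_thought_labels(thought: str) -> str:
    # Remove a single leading "Thought:" (case-insensitive)...
    thought = re.sub(r"(?i)^thought:\s*", "", thought, count=1)
    # ...then any "Thought:" that starts a subsequent line.
    thought = re.sub(r"(?i)\nthought:\s*", "\n", thought)
    return thought.strip()


print(strip_thought_labels("Thought: search first.\nThought: refine the query."))
# -> search first.
#    refine the query.
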
@@ -423,7 +423,11 @@ class LiteAgent(FlowTrackable, BaseModel):
         )
 
         # Add response format instructions if specified
-        if self.response_format:
+        if (
+            self.response_format
+            and isinstance(self.llm, BaseLLM)
+            and not self.llm.supports_function_calling()
+        ):
             schema = generate_model_description(self.response_format)
             base_prompt += self.i18n.slice("lite_agent_response_format").format(
                 response_format=schema

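The gate means schema instructions are injected into the prompt only as a fallback for LLMs without native function calling; otherwise provider-side structured output takes over. A rough sketch of the decision, using pydantic's `model_json_schema()` as a stand-in for crewAI's `generate_model_description` helper:

import json

from pydantic import BaseModel


class Report(BaseModel):
    title: str
    score: float


def format_instructions(supports_function_calling: bool) -> str:
    if supports_function_calling:
        # Native structured output: keep the prompt clean.
        return ""
    # Fallback: spell out the expected JSON shape in the prompt.
    return "Respond only with JSON matching this schema:\n" + json.dumps(
        Report.model_json_schema(), indent=2
    )


print(format_instructions(supports_function_calling=False))
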
@@ -478,6 +482,7 @@ class LiteAgent(FlowTrackable, BaseModel):
                 callbacks=self._callbacks,
                 printer=self._printer,
                 from_agent=self,
+                response_model=self.response_format,
             )
 
         except Exception as e:

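With `response_model=self.response_format` forwarded into the LLM call, the agent's declared output type reaches the provider layer directly. A hedged usage sketch (assuming `LiteAgent` is importable from `crewai.lite_agent` and that its output exposes a `pydantic` attribute):

from pydantic import BaseModel

from crewai.lite_agent import LiteAgent  # assumed import path


class CityFacts(BaseModel):
    city: str
    population: int


agent = LiteAgent(
    role="Researcher",
    goal="Answer with structured data",
    backstory="A meticulous fact finder.",
    llm="gpt-4o-mini",
    response_format=CityFacts,  # now forwarded to the LLM as response_model
)

result = agent.kickoff("What is the population of Lisbon?")
print(result.pydantic)  # a CityFacts instance when parsing succeeds
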
@@ -750,15 +750,14 @@ class LLM(BaseLLM):
                     llm=self,
                 )
                 result = instructor_instance.to_pydantic()
-                structured_response = result.model_dump_json()
                 self._handle_emit_call_events(
-                    response=structured_response,
+                    response=result.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_response
+                return result
 
             self._handle_emit_call_events(
                 response=full_response,

@@ -941,15 +940,14 @@ class LLM(BaseLLM):
                     llm=self,
                 )
                 result = instructor_instance.to_pydantic()
-                structured_response = result.model_dump_json()
                 self._handle_emit_call_events(
-                    response=structured_response,
+                    response=result.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_response
+                return result
 
         try:
             # Attempt to make the completion call, but catch context window errors

@@ -969,15 +967,14 @@ class LLM(BaseLLM):
         if response_model is not None:
             # When using instructor/response_model, litellm returns a Pydantic model instance
             if isinstance(response, BaseModel):
-                structured_response = response.model_dump_json()
                 self._handle_emit_call_events(
-                    response=structured_response,
+                    response=response.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_response
+                return response
 
         # --- 3) Extract response message and content (standard response)
         response_message = cast(Choices, cast(ModelResponse, response).choices)[

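All three structured-output paths in `LLM.call` now return the validated `BaseModel` instance instead of its JSON dump, while the emitted event still carries the JSON string. Callers should treat the return type as `str | BaseModel`; a caller-side sketch:

from pydantic import BaseModel


def as_text(answer: str | BaseModel) -> str:
    # Events and logs want a string; typed consumers keep the object.
    if isinstance(answer, BaseModel):
        return answer.model_dump_json()
    return answer
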
@@ -179,6 +179,14 @@ class BaseLLM(ABC):
         """
         return DEFAULT_SUPPORTS_STOP_WORDS
 
+    @abstractmethod
+    def supports_function_calling(self) -> bool:
+        """Check if the LLM supports function calling.
+
+        Returns:
+            True if the LLM supports function calling, False otherwise.
+        """
+
     def _supports_stop_words_implementation(self) -> bool:
         """Check if stop words are configured for this LLM instance.
 

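Making `supports_function_calling` abstract forces every provider to declare its capability, which the LiteAgent prompt gate above depends on. A minimal custom-provider sketch (import path assumed; `call` signature abbreviated):

from typing import Any

from crewai.llms.base_llm import BaseLLM  # assumed import path


class EchoLLM(BaseLLM):
    """Toy provider showing the newly required override."""

    def call(self, messages: Any, **kwargs: Any) -> str:  # signature abbreviated
        return str(messages)

    def supports_function_calling(self) -> bool:
        # No tool-use API here, so agents fall back to prompt-based
        # schema instructions.
        return False
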
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-
-import json
 import logging
 import os
 from typing import Any, cast

@@ -47,7 +45,7 @@ class AnthropicCompletion(BaseLLM):
         stop_sequences: list[str] | None = None,
         stream: bool = False,
         client_params: dict[str, Any] | None = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """Initialize Anthropic chat completion client.
 

@@ -110,7 +108,7 @@ class AnthropicCompletion(BaseLLM):
     def call(
         self,
         messages: str | list[LLMMessage],
-        tools: list[dict] | None = None,
+        tools: list[dict[str, Any]] | None = None,
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,

@@ -133,7 +131,7 @@ class AnthropicCompletion(BaseLLM):
         try:
             # Emit call started event
             self._emit_call_started_event(
-                messages=messages,  # type: ignore[arg-type]
+                messages=messages,
                 tools=tools,
                 callbacks=callbacks,
                 available_functions=available_functions,

@@ -143,7 +141,7 @@ class AnthropicCompletion(BaseLLM):
 
             # Format messages for Anthropic
             formatted_messages, system_message = self._format_messages_for_anthropic(
-                messages  # type: ignore[arg-type]
+                messages
             )
 
             # Prepare completion parameters

@@ -181,7 +179,7 @@ class AnthropicCompletion(BaseLLM):
         self,
         messages: list[LLMMessage],
         system_message: str | None = None,
-        tools: list[dict] | None = None,
+        tools: list[dict[str, Any]] | None = None,
     ) -> dict[str, Any]:
         """Prepare parameters for Anthropic messages API.
 

@@ -218,7 +216,9 @@ class AnthropicCompletion(BaseLLM):
 
         return params
 
-    def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
+    def _convert_tools_for_interference(
+        self, tools: list[dict[str, Any]]
+    ) -> list[dict[str, Any]]:
         """Convert CrewAI tool format to Anthropic tool use format."""
         anthropic_tools = []
 

@@ -336,17 +336,17 @@ class AnthropicCompletion(BaseLLM):
             ]
             if tool_uses and tool_uses[0].name == "structured_output":
                 structured_data = tool_uses[0].input
-                structured_json = json.dumps(structured_data)
+                parsed_object = response_model.model_validate(structured_data)
 
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
 
-                return structured_json
+                return parsed_object
 
             # Check if Claude wants to use tools
             if response.content and available_functions:

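The behavioral change here: the forced `structured_output` tool's arguments are validated into the response model (raising on a bad shape) rather than re-serialized blindly with `json.dumps`, and the typed object is returned. The core step in isolation, pure pydantic:

from pydantic import BaseModel


class Answer(BaseModel):
    verdict: str
    confidence: float


# The "structured_output" tool hands back plain dict arguments.
tool_input = {"verdict": "pass", "confidence": 0.97}

parsed = Answer.model_validate(tool_input)  # raises ValidationError on mismatch
print(parsed.model_dump_json())             # the string emitted with the event
# The call itself now returns `parsed`, not the JSON string.
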
@@ -394,7 +394,7 @@ class AnthropicCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Handle streaming message completion."""
         if response_model:
             structured_tool = {

@@ -437,17 +437,17 @@ class AnthropicCompletion(BaseLLM):
             ]
             if tool_uses and tool_uses[0].name == "structured_output":
                 structured_data = tool_uses[0].input
-                structured_json = json.dumps(structured_data)
+                parsed_object = response_model.model_validate(structured_data)
 
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
 
-                return structured_json
+                return parsed_object
 
             if final_message.content and available_functions:
                 tool_uses = [

@@ -307,15 +307,14 @@ class OpenAICompletion(BaseLLM):
 
             parsed_object = parsed_response.choices[0].message.parsed
             if parsed_object:
-                structured_json = parsed_object.model_dump_json()
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_json
+                return parsed_object
 
         response: ChatCompletion = self.client.chat.completions.create(**params)
 

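For OpenAI, the non-streaming path builds on the SDK's parse helper, which returns `message.parsed` as a model instance; the wrapper now returns that object as-is. A sketch of the underlying SDK call (on older SDK versions this lives at `client.beta.chat.completions.parse`):

from openai import OpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    verdict: str
    confidence: float


client = OpenAI()
completion = client.chat.completions.parse(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Grade this answer: 2 + 2 = 4"}],
    response_format=Answer,
)
print(completion.choices[0].message.parsed)  # Answer instance or None
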
@@ -412,7 +411,7 @@ class OpenAICompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Handle streaming chat completion."""
         full_response = ""
         tool_calls = {}

@@ -440,17 +439,16 @@ class OpenAICompletion(BaseLLM):
 
             try:
                 parsed_object = response_model.model_validate_json(accumulated_content)
-                structured_json = parsed_object.model_dump_json()
 
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
 
-                return structured_json
+                return parsed_object
             except Exception as e:
                 logging.error(f"Failed to parse structured output from stream: {e}")
                 self._emit_call_completed_event(

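Streaming has no parse helper, so the accumulated text is validated wholesale once the stream ends, with a logged fallback on failure. The essence, pydantic plus stdlib logging:

import logging

from pydantic import BaseModel, ValidationError


class Answer(BaseModel):
    verdict: str
    confidence: float


accumulated_content = '{"verdict": "pass", "confidence": 0.9}'  # joined chunks

try:
    parsed = Answer.model_validate_json(accumulated_content)
    print(parsed)  # returned to the caller as a typed object
except ValidationError as e:
    logging.error(f"Failed to parse structured output from stream: {e}")
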
@@ -149,7 +149,7 @@ def handle_max_iterations_exceeded(
             color="yellow",
         )
 
-    if formatted_answer and hasattr(formatted_answer, "text"):
+    if formatted_answer and formatted_answer.text:
         assistant_message = (
             formatted_answer.text + f"\n{i18n.errors('force_final_answer')}"
         )

@@ -291,17 +291,23 @@ def get_llm_response(
 
 
 def process_llm_response(
-    answer: str, use_stop_words: bool
+    answer: str | BaseModel, use_stop_words: bool
 ) -> AgentAction | AgentFinish:
     """Process the LLM response and format it into an AgentAction or AgentFinish.
 
     Args:
-        answer: The raw response from the LLM
+        answer: The raw response from the LLM (string) or structured output (BaseModel)
        use_stop_words: Whether to use stop words in the LLM call
 
     Returns:
        Either an AgentAction or AgentFinish
    """
+    if isinstance(answer, BaseModel):
+        json_output = answer.model_dump_json()
+        return AgentFinish(
+            thought="", output=json_output, text=json_output, pydantic=answer
+        )
+
     if not use_stop_words:
         try:
             # Preliminary parsing to check for errors.

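The new guard short-circuits ReAct parsing entirely: a structured reply is wrapped straight into an `AgentFinish` whose `output`/`text` carry the JSON dump and whose `pydantic` field carries the object. Expected behavior, assuming `process_llm_response` lives in `crewai.utilities.agent_utils`:

from pydantic import BaseModel

from crewai.utilities.agent_utils import process_llm_response  # assumed path


class Answer(BaseModel):
    verdict: str


finish = process_llm_response(Answer(verdict="pass"), use_stop_words=True)
print(finish.pydantic)  # Answer(verdict='pass')
print(finish.output)    # {"verdict":"pass"}
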