feat: pass response model to lite agent
@@ -5,10 +5,18 @@ the ReAct (Reasoning and Acting) format, converting them into structured
 AgentAction or AgentFinish objects.
 """

 from __future__ import annotations

+from dataclasses import dataclass
+import re
+from typing import TYPE_CHECKING
+
 from json_repair import repair_json  # type: ignore[import-untyped]

+
+if TYPE_CHECKING:
+    from pydantic import BaseModel
+
 from crewai.agents.constants import (
     ACTION_INPUT_ONLY_REGEX,
     ACTION_INPUT_REGEX,
@@ -42,6 +50,7 @@ class AgentFinish:
     thought: str
     output: str
     text: str
+    pydantic: BaseModel | None = None  # Optional structured output from response_model


 class OutputParserError(Exception):
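Illustration (not part of the diff): with the new field, callers that set a response model can recover the typed object straight from the finish step. `Report` is a hypothetical response model.

from pydantic import BaseModel

class Report(BaseModel):  # hypothetical response model
    title: str
    score: float

report = Report(title="Q3 summary", score=0.92)
finish = AgentFinish(
    thought="",                        # structured outputs carry no ReAct thought
    output=report.model_dump_json(),   # JSON string, as before
    text=report.model_dump_json(),
    pydantic=report,                   # new: the validated object itself
)
assert finish.pydantic.title == "Q3 summary"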
@@ -140,7 +149,7 @@ def _extract_thought(text: str) -> str:
         text: The full agent output text.

     Returns:
-        The extracted thought string.
+        The extracted thought string with duplicate consecutive "Thought:" prefixes removed.
     """
     thought_index = text.find("\nAction")
     if thought_index == -1:
@@ -149,7 +158,13 @@ def _extract_thought(text: str) -> str:
         return ""
     thought = text[:thought_index].strip()
     # Remove any triple backticks from the thought string
-    return thought.replace("```", "").strip()
+    thought = thought.replace("```", "").strip()
+
+    thought = re.sub(r"(?i)^thought:\s*", "", thought, count=1)
+
+    thought = re.sub(r"(?i)\nthought:\s*", "\n", thought)
+
+    return thought.strip()


 def _clean_action(text: str) -> str:
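Traced on a made-up sample, the new logic collapses a doubled prefix instead of returning it verbatim:

sample = "Thought: Thought: I should search first.\nAction: search\nAction Input: {}"
# Text before "\nAction" is kept, backticks are stripped, then the leading
# "Thought:" prefix is removed exactly once, so the doubled prefix collapses:
assert _extract_thought(sample) == "Thought: I should search first."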
@@ -423,7 +423,11 @@ class LiteAgent(FlowTrackable, BaseModel):
             )

         # Add response format instructions if specified
-        if self.response_format:
+        if (
+            self.response_format
+            and isinstance(self.llm, BaseLLM)
+            and not self.llm.supports_function_calling()
+        ):
             schema = generate_model_description(self.response_format)
             base_prompt += self.i18n.slice("lite_agent_response_format").format(
                 response_format=schema
@@ -478,6 +482,7 @@ class LiteAgent(FlowTrackable, BaseModel):
                 callbacks=self._callbacks,
                 printer=self._printer,
                 from_agent=self,
+                response_model=self.response_format,
             )

         except Exception as e:
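Net effect for users, as a hedged sketch (import path and agent fields are assumptions): a LiteAgent built with response_format now gets native structured output whenever its LLM supports function calling, and only falls back to schema-in-prompt instructions otherwise.

from pydantic import BaseModel
from crewai.lite_agent import LiteAgent  # import path assumed

class CityFacts(BaseModel):  # illustrative response model
    city: str
    population: int

agent = LiteAgent(
    role="Researcher",
    goal="Return structured city data",
    backstory="Careful with numbers.",
    llm="gpt-4o",               # supports function calling, so no prompt-injected schema
    response_format=CityFacts,  # now also forwarded as response_model on the LLM call
)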
@@ -750,15 +750,14 @@ class LLM(BaseLLM):
                     llm=self,
                 )
                 result = instructor_instance.to_pydantic()
-                structured_response = result.model_dump_json()
                 self._handle_emit_call_events(
-                    response=structured_response,
+                    response=result.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_response
+                return result

             self._handle_emit_call_events(
                 response=full_response,
@@ -941,15 +940,14 @@ class LLM(BaseLLM):
                     llm=self,
                 )
                 result = instructor_instance.to_pydantic()
-                structured_response = result.model_dump_json()
                 self._handle_emit_call_events(
-                    response=structured_response,
+                    response=result.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_response
+                return result

         try:
             # Attempt to make the completion call, but catch context window errors
@@ -969,15 +967,14 @@ class LLM(BaseLLM):
             if response_model is not None:
                 # When using instructor/response_model, litellm returns a Pydantic model instance
                 if isinstance(response, BaseModel):
-                    structured_response = response.model_dump_json()
                     self._handle_emit_call_events(
-                        response=structured_response,
+                        response=response.model_dump_json(),
                         call_type=LLMCallType.LLM_CALL,
                         from_task=from_task,
                         from_agent=from_agent,
                         messages=params["messages"],
                     )
-                    return structured_response
+                    return response

             # --- 3) Extract response message and content (standard response)
             response_message = cast(Choices, cast(ModelResponse, response).choices)[
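Taken together, these three hunks change the return contract: with a response_model, LLM.call() now hands back the Pydantic instance rather than its JSON dump, while the emitted events still carry the JSON string. A sketch under that assumption (model and prompt are illustrative):

from pydantic import BaseModel
from crewai.llm import LLM  # import path assumed

class CityFacts(BaseModel):  # illustrative response model
    city: str
    population: int

llm = LLM(model="gpt-4o")
result = llm.call("Population of Paris?", response_model=CityFacts)

assert isinstance(result, CityFacts)    # typed object; previously a JSON string
payload = result.model_dump_json()      # what the completed-call event carries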
@@ -179,6 +179,14 @@ class BaseLLM(ABC):
         """
         return DEFAULT_SUPPORTS_STOP_WORDS

+    @abstractmethod
+    def supports_function_calling(self) -> bool:
+        """Check if the LLM supports function calling.
+
+        Returns:
+            True if the LLM supports function calling, False otherwise.
+        """
+
     def _supports_stop_words_implementation(self) -> bool:
         """Check if stop words are configured for this LLM instance.
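Making supports_function_calling() abstract is a breaking change for custom providers: every BaseLLM subclass must now declare the capability. A minimal sketch (class name and abbreviated signature are illustrative):

class MyCustomLLM(BaseLLM):
    def call(self, messages, tools=None, callbacks=None,
             available_functions=None, from_task=None, from_agent=None):
        ...  # provider-specific completion logic

    def supports_function_calling(self) -> bool:
        # True only if the backing API accepts tool/function schemas;
        # LiteAgent uses this to pick native structured output vs. prompt schema.
        return False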
@@ -1,7 +1,5 @@
 from __future__ import annotations

-
-import json
 import logging
 import os
 from typing import Any, cast
@@ -47,7 +45,7 @@ class AnthropicCompletion(BaseLLM):
         stop_sequences: list[str] | None = None,
         stream: bool = False,
         client_params: dict[str, Any] | None = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """Initialize Anthropic chat completion client.

@@ -110,7 +108,7 @@ class AnthropicCompletion(BaseLLM):
     def call(
         self,
         messages: str | list[LLMMessage],
-        tools: list[dict] | None = None,
+        tools: list[dict[str, Any]] | None = None,
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,
@@ -133,7 +131,7 @@ class AnthropicCompletion(BaseLLM):
         try:
             # Emit call started event
             self._emit_call_started_event(
-                messages=messages,  # type: ignore[arg-type]
+                messages=messages,
                 tools=tools,
                 callbacks=callbacks,
                 available_functions=available_functions,
@@ -143,7 +141,7 @@ class AnthropicCompletion(BaseLLM):

             # Format messages for Anthropic
             formatted_messages, system_message = self._format_messages_for_anthropic(
-                messages  # type: ignore[arg-type]
+                messages
             )

             # Prepare completion parameters
@@ -181,7 +179,7 @@ class AnthropicCompletion(BaseLLM):
         self,
         messages: list[LLMMessage],
         system_message: str | None = None,
-        tools: list[dict] | None = None,
+        tools: list[dict[str, Any]] | None = None,
     ) -> dict[str, Any]:
         """Prepare parameters for Anthropic messages API.

@@ -218,7 +216,9 @@ class AnthropicCompletion(BaseLLM):

         return params

-    def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
+    def _convert_tools_for_interference(
+        self, tools: list[dict[str, Any]]
+    ) -> list[dict[str, Any]]:
         """Convert CrewAI tool format to Anthropic tool use format."""
         anthropic_tools = []

@@ -336,17 +336,17 @@ class AnthropicCompletion(BaseLLM):
             ]
             if tool_uses and tool_uses[0].name == "structured_output":
                 structured_data = tool_uses[0].input
-                structured_json = json.dumps(structured_data)
+                parsed_object = response_model.model_validate(structured_data)

                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )

-                return structured_json
+                return parsed_object

             # Check if Claude wants to use tools
             if response.content and available_functions:
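The substantive swap here: the forced "structured_output" tool's input dict is validated into the caller's model instead of being re-serialized with json.dumps, so malformed payloads now fail loudly at the call site. Roughly (model and payload are illustrative):

from pydantic import BaseModel, ValidationError

class CityFacts(BaseModel):  # illustrative response model
    city: str
    population: int

structured_data = {"city": "Paris", "population": 2100000}  # shape of tool_uses[0].input

parsed = CityFacts.model_validate(structured_data)  # new path: typed and validated
# The old path, json.dumps(structured_data), passed bad payloads through silently:
try:
    CityFacts.model_validate({"city": "Paris"})     # missing required field
except ValidationError:
    pass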
@@ -394,7 +394,7 @@ class AnthropicCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Handle streaming message completion."""
         if response_model:
             structured_tool = {
@@ -437,17 +437,17 @@ class AnthropicCompletion(BaseLLM):
             ]
             if tool_uses and tool_uses[0].name == "structured_output":
                 structured_data = tool_uses[0].input
-                structured_json = json.dumps(structured_data)
+                parsed_object = response_model.model_validate(structured_data)

                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )

-                return structured_json
+                return parsed_object

             if final_message.content and available_functions:
                 tool_uses = [
@@ -307,15 +307,14 @@ class OpenAICompletion(BaseLLM):

             parsed_object = parsed_response.choices[0].message.parsed
             if parsed_object:
-                structured_json = parsed_object.model_dump_json()
                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )
-                return structured_json
+                return parsed_object

         response: ChatCompletion = self.client.chat.completions.create(**params)

@@ -412,7 +411,7 @@ class OpenAICompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | BaseModel:
         """Handle streaming chat completion."""
         full_response = ""
         tool_calls = {}
@@ -440,17 +439,16 @@ class OpenAICompletion(BaseLLM):

             try:
                 parsed_object = response_model.model_validate_json(accumulated_content)
-                structured_json = parsed_object.model_dump_json()

                 self._emit_call_completed_event(
-                    response=structured_json,
+                    response=parsed_object.model_dump_json(),
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
                 )

-                return structured_json
+                return parsed_object
             except Exception as e:
                 logging.error(f"Failed to parse structured output from stream: {e}")
                 self._emit_call_completed_event(
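Same idea in the OpenAI streaming path, with one wrinkle: the accumulated chunks form a JSON string, so model_validate_json parses and validates in a single step. Illustrative sketch:

from pydantic import BaseModel

class CityFacts(BaseModel):  # illustrative response model
    city: str
    population: int

accumulated_content = '{"city": "Paris", "population": 2100000}'  # stand-in for joined chunks

parsed = CityFacts.model_validate_json(accumulated_content)
# The completed-call event still gets parsed.model_dump_json(); the caller gets `parsed`.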
@@ -149,7 +149,7 @@ def handle_max_iterations_exceeded(
             color="yellow",
         )

-    if formatted_answer and hasattr(formatted_answer, "text"):
+    if formatted_answer and formatted_answer.text:
        assistant_message = (
            formatted_answer.text + f"\n{i18n.errors('force_final_answer')}"
        )
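Why the guard changed: AgentAction and AgentFinish both define .text, so the hasattr check was always true for parsed answers; checking truthiness actually filters out empty text. Sketch:

finish = AgentFinish(thought="", output="", text="", pydantic=None)
assert hasattr(finish, "text")  # old check: passes even when text is empty
assert not finish.text          # new check: empty text skips the forced-answer nudge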
@@ -291,17 +291,23 @@ def get_llm_response(


 def process_llm_response(
-    answer: str, use_stop_words: bool
+    answer: str | BaseModel, use_stop_words: bool
 ) -> AgentAction | AgentFinish:
     """Process the LLM response and format it into an AgentAction or AgentFinish.

     Args:
-        answer: The raw response from the LLM
+        answer: The raw response from the LLM (string) or structured output (BaseModel)
         use_stop_words: Whether to use stop words in the LLM call

     Returns:
         Either an AgentAction or AgentFinish
     """
+    if isinstance(answer, BaseModel):
+        json_output = answer.model_dump_json()
+        return AgentFinish(
+            thought="", output=json_output, text=json_output, pydantic=answer
+        )
+
     if not use_stop_words:
         try:
             # Preliminary parsing to check for errors.
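End to end, the short-circuit means a structured answer never touches the ReAct parser. A sketch (model is illustrative):

from pydantic import BaseModel

class CityFacts(BaseModel):  # illustrative response model
    city: str
    population: int

answer = CityFacts(city="Paris", population=2100000)  # e.g. what llm.call() now returns

result = process_llm_response(answer, use_stop_words=True)
assert isinstance(result, AgentFinish)
assert result.pydantic is answer
assert result.output == answer.model_dump_json()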