feat: pass response model to lite agent

commit 7380f7b794
parent de2e516995
Greyson Lalonde, 2025-11-05 07:13:03 -05:00
7 changed files with 66 additions and 37 deletions

View File

@@ -5,10 +5,18 @@ the ReAct (Reasoning and Acting) format, converting them into structured
AgentAction or AgentFinish objects.
"""
from __future__ import annotations
from dataclasses import dataclass
import re
from typing import TYPE_CHECKING
from json_repair import repair_json # type: ignore[import-untyped]
if TYPE_CHECKING:
from pydantic import BaseModel
from crewai.agents.constants import (
ACTION_INPUT_ONLY_REGEX,
ACTION_INPUT_REGEX,
@@ -42,6 +50,7 @@ class AgentFinish:
thought: str
output: str
text: str
pydantic: BaseModel | None = None # Optional structured output from response_model
class OutputParserError(Exception):
@@ -140,7 +149,7 @@ def _extract_thought(text: str) -> str:
text: The full agent output text.
Returns:
The extracted thought string.
The extracted thought string with duplicate consecutive "Thought:" prefixes removed.
"""
thought_index = text.find("\nAction")
if thought_index == -1:
@@ -149,7 +158,13 @@ def _extract_thought(text: str) -> str:
return ""
thought = text[:thought_index].strip()
# Remove any triple backticks from the thought string
return thought.replace("```", "").strip()
thought = thought.replace("```", "").strip()
thought = re.sub(r"(?i)^thought:\s*", "", thought, count=1)
thought = re.sub(r"(?i)\nthought:\s*", "\n", thought)
return thought.strip()
def _clean_action(text: str) -> str:
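
A quick check of the new prefix stripping, as a standalone sketch (the helper below mirrors the patched _extract_thought so it can run outside the module):

import re

def extract_thought(text: str) -> str:
    # Mirrors the patched _extract_thought: keep everything before "\nAction",
    # drop code fences, then strip a leading "Thought:" once and any
    # "Thought:" prefixes at the start of later lines.
    thought_index = text.find("\nAction")
    if thought_index == -1:
        return ""
    thought = text[:thought_index].strip()
    thought = thought.replace("```", "").strip()
    thought = re.sub(r"(?i)^thought:\s*", "", thought, count=1)
    thought = re.sub(r"(?i)\nthought:\s*", "\n", thought)
    return thought.strip()

print(extract_thought("Thought: step one\nThought: step two\nAction: search"))
# -> "step one\nstep two"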

View File

@@ -423,7 +423,11 @@ class LiteAgent(FlowTrackable, BaseModel):
)
# Add response format instructions if specified
if self.response_format:
if (
self.response_format
and isinstance(self.llm, BaseLLM)
and not self.llm.supports_function_calling()
):
schema = generate_model_description(self.response_format)
base_prompt += self.i18n.slice("lite_agent_response_format").format(
response_format=schema
@@ -478,6 +482,7 @@ class LiteAgent(FlowTrackable, BaseModel):
callbacks=self._callbacks,
printer=self._printer,
from_agent=self,
response_model=self.response_format,
)
except Exception as e:
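
Net effect in LiteAgent: the prompt-side schema slice is only injected when the model cannot do native function calling; otherwise response_format is forwarded to the LLM call as response_model. A usage sketch (constructor arguments other than response_format are assumed from typical CrewAI agents, and the import path may differ):

from pydantic import BaseModel
from crewai import LLM
from crewai.lite_agent import LiteAgent  # import path assumed

class CityFacts(BaseModel):
    city: str
    population: int

agent = LiteAgent(
    role="researcher",
    goal="answer factual questions",
    backstory="",
    llm=LLM(model="gpt-4o-mini"),  # supports function calling: no schema text in prompt
    response_format=CityFacts,     # now also forwarded as response_model
)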

View File

@@ -750,15 +750,14 @@ class LLM(BaseLLM):
llm=self,
)
result = instructor_instance.to_pydantic()
structured_response = result.model_dump_json()
self._handle_emit_call_events(
response=structured_response,
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_response
return result
self._handle_emit_call_events(
response=full_response,
@@ -941,15 +940,14 @@ class LLM(BaseLLM):
llm=self,
)
result = instructor_instance.to_pydantic()
structured_response = result.model_dump_json()
self._handle_emit_call_events(
response=structured_response,
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_response
return result
try:
# Attempt to make the completion call, but catch context window errors
@@ -969,15 +967,14 @@ class LLM(BaseLLM):
if response_model is not None:
# When using instructor/response_model, litellm returns a Pydantic model instance
if isinstance(response, BaseModel):
structured_response = response.model_dump_json()
self._handle_emit_call_events(
response=structured_response,
response=response.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_response
return response
# --- 3) Extract response message and content (standard response)
response_message = cast(Choices, cast(ModelResponse, response).choices)[
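
Across all three return paths above, the caller now receives the Pydantic instance itself, while the emitted event still carries its JSON dump. Roughly, from the caller's side (the model name, import path, and the response_model kwarg on the public call are assumptions based on this diff):

from pydantic import BaseModel
from crewai.llm import LLM  # import path assumed

class Summary(BaseModel):
    title: str

llm = LLM(model="gpt-4o-mini")
result = llm.call("Summarize the Q3 report", response_model=Summary)
assert isinstance(result, Summary)  # previously: a JSON string
print(result.model_dump_json())     # the event payload still looks like this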

View File

@@ -179,6 +179,14 @@ class BaseLLM(ABC):
"""
return DEFAULT_SUPPORTS_STOP_WORDS
@abstractmethod
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
True if the LLM supports function calling, False otherwise.
"""
def _supports_stop_words_implementation(self) -> bool:
"""Check if stop words are configured for this LLM instance.

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
import json
import logging
import os
from typing import Any, cast
@@ -47,7 +45,7 @@ class AnthropicCompletion(BaseLLM):
stop_sequences: list[str] | None = None,
stream: bool = False,
client_params: dict[str, Any] | None = None,
**kwargs,
**kwargs: Any,
):
"""Initialize Anthropic chat completion client.
@@ -110,7 +108,7 @@ class AnthropicCompletion(BaseLLM):
def call(
self,
messages: str | list[LLMMessage],
tools: list[dict] | None = None,
tools: list[dict[str, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -133,7 +131,7 @@ class AnthropicCompletion(BaseLLM):
try:
# Emit call started event
self._emit_call_started_event(
messages=messages, # type: ignore[arg-type]
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
@@ -143,7 +141,7 @@ class AnthropicCompletion(BaseLLM):
# Format messages for Anthropic
formatted_messages, system_message = self._format_messages_for_anthropic(
messages # type: ignore[arg-type]
messages
)
# Prepare completion parameters
@@ -181,7 +179,7 @@ class AnthropicCompletion(BaseLLM):
self,
messages: list[LLMMessage],
system_message: str | None = None,
tools: list[dict] | None = None,
tools: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for Anthropic messages API.
@@ -218,7 +216,9 @@ class AnthropicCompletion(BaseLLM):
return params
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
def _convert_tools_for_interference(
self, tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Convert CrewAI tool format to Anthropic tool use format."""
anthropic_tools = []
@@ -336,17 +336,17 @@ class AnthropicCompletion(BaseLLM):
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
parsed_object = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=structured_json,
response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return parsed_object
# Check if Claude wants to use tools
if response.content and available_functions:
@@ -394,7 +394,7 @@ class AnthropicCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | BaseModel:
"""Handle streaming message completion."""
if response_model:
structured_tool = {
@@ -437,17 +437,17 @@ class AnthropicCompletion(BaseLLM):
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
parsed_object = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=structured_json,
response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return parsed_object
if final_message.content and available_functions:
tool_uses = [
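
In both the non-streaming and streaming Anthropic paths, the structured_output tool input is now validated into the response model before anything is emitted or returned; the conversion step in isolation (a standalone sketch, not the provider code):

from pydantic import BaseModel

class Weather(BaseModel):
    city: str
    temp_c: float

structured_data = {"city": "Oslo", "temp_c": -3.5}  # stands in for tool_uses[0].input
parsed_object = Weather.model_validate(structured_data)
print(parsed_object.model_dump_json())  # what _emit_call_completed_event receives
# the call itself now returns parsed_object, not this JSON string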

View File

@@ -307,15 +307,14 @@ class OpenAICompletion(BaseLLM):
parsed_object = parsed_response.choices[0].message.parsed
if parsed_object:
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return parsed_object
response: ChatCompletion = self.client.chat.completions.create(**params)
@@ -412,7 +411,7 @@ class OpenAICompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | BaseModel:
"""Handle streaming chat completion."""
full_response = ""
tool_calls = {}
@@ -440,17 +439,16 @@ class OpenAICompletion(BaseLLM):
try:
parsed_object = response_model.model_validate_json(accumulated_content)
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
response=parsed_object.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
return parsed_object
except Exception as e:
logging.error(f"Failed to parse structured output from stream: {e}")
self._emit_call_completed_event(
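
The OpenAI streaming path validates the accumulated deltas the same way, falling through to the error log on a parse failure; that step in isolation:

from pydantic import BaseModel

class Summary(BaseModel):
    title: str

accumulated_content = '{"title": "Q3 report"}'  # concatenated stream chunks
try:
    parsed_object = Summary.model_validate_json(accumulated_content)
    print(parsed_object.title)  # -> "Q3 report"
except Exception as e:
    print(f"Failed to parse structured output from stream: {e}")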

View File

@@ -149,7 +149,7 @@ def handle_max_iterations_exceeded(
color="yellow",
)
if formatted_answer and hasattr(formatted_answer, "text"):
if formatted_answer and formatted_answer.text:
assistant_message = (
formatted_answer.text + f"\n{i18n.errors('force_final_answer')}"
)
@@ -291,17 +291,23 @@ def get_llm_response(
def process_llm_response(
answer: str, use_stop_words: bool
answer: str | BaseModel, use_stop_words: bool
) -> AgentAction | AgentFinish:
"""Process the LLM response and format it into an AgentAction or AgentFinish.
Args:
answer: The raw response from the LLM
answer: The raw response from the LLM (string) or structured output (BaseModel)
use_stop_words: Whether to use stop words in the LLM call
Returns:
Either an AgentAction or AgentFinish
"""
if isinstance(answer, BaseModel):
json_output = answer.model_dump_json()
return AgentFinish(
thought="", output=json_output, text=json_output, pydantic=answer
)
if not use_stop_words:
try:
# Preliminary parsing to check for errors.
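
So a BaseModel answer short-circuits straight to AgentFinish, with both output and text set to the JSON dump and the instance preserved on the new pydantic field; roughly (import path assumed):

from pydantic import BaseModel
from crewai.agents.parser import AgentFinish  # import path assumed

class Answer(BaseModel):
    value: int

answer = Answer(value=42)
json_output = answer.model_dump_json()
finish = AgentFinish(thought="", output=json_output, text=json_output, pydantic=answer)
print(finish.pydantic)  # Answer(value=42)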