mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-24 15:48:23 +00:00
Compare commits
3 Commits
devin/1768
...
devin/1757
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
61e99a61f0 | ||
|
|
a5617cbfff | ||
|
|
934c63ede1 |
@@ -133,6 +133,10 @@ select = [
|
|||||||
]
|
]
|
||||||
ignore = ["E501"] # ignore line too long
|
ignore = ["E501"] # ignore line too long
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"tests/**/*.py" = ["S101"] # Allow assert statements in tests
|
||||||
|
"src/crewai/lite_agent.py" = ["PERF203"] # Allow try-except in loop for LLM parsing
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
exclude = ["src/crewai/cli/templates", "tests"]
|
exclude = ["src/crewai/cli/templates", "tests"]
|
||||||
|
|
||||||
|
|||||||
@@ -1,25 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
cast,
|
cast,
|
||||||
get_args,
|
get_args,
|
||||||
get_origin,
|
get_origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
try:
|
|
||||||
from typing import Self
|
|
||||||
except ImportError:
|
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
UUID4,
|
UUID4,
|
||||||
@@ -27,8 +17,8 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
InstanceOf,
|
InstanceOf,
|
||||||
PrivateAttr,
|
PrivateAttr,
|
||||||
model_validator,
|
|
||||||
field_validator,
|
field_validator,
|
||||||
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
@@ -39,12 +29,18 @@ from crewai.agents.parser import (
|
|||||||
AgentFinish,
|
AgentFinish,
|
||||||
OutputParserException,
|
OutputParserException,
|
||||||
)
|
)
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.agent_events import (
|
||||||
|
LiteAgentExecutionCompletedEvent,
|
||||||
|
LiteAgentExecutionErrorEvent,
|
||||||
|
LiteAgentExecutionStartedEvent,
|
||||||
|
)
|
||||||
|
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||||
from crewai.flow.flow_trackable import FlowTrackable
|
from crewai.flow.flow_trackable import FlowTrackable
|
||||||
from crewai.llm import LLM, BaseLLM
|
from crewai.llm import LLM, BaseLLM
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
from crewai.tools.structured_tool import CrewStructuredTool
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
from crewai.utilities import I18N
|
from crewai.utilities import I18N
|
||||||
from crewai.utilities.guardrail import process_guardrail
|
|
||||||
from crewai.utilities.agent_utils import (
|
from crewai.utilities.agent_utils import (
|
||||||
enforce_rpm_limit,
|
enforce_rpm_limit,
|
||||||
format_message_for_llm,
|
format_message_for_llm,
|
||||||
@@ -61,15 +57,8 @@ from crewai.utilities.agent_utils import (
|
|||||||
process_llm_response,
|
process_llm_response,
|
||||||
render_text_description_and_args,
|
render_text_description_and_args,
|
||||||
)
|
)
|
||||||
from crewai.utilities.converter import generate_model_description
|
from crewai.utilities.converter import convert_to_model, generate_model_description
|
||||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
from crewai.utilities.guardrail import process_guardrail
|
||||||
from crewai.events.types.agent_events import (
|
|
||||||
LiteAgentExecutionCompletedEvent,
|
|
||||||
LiteAgentExecutionErrorEvent,
|
|
||||||
LiteAgentExecutionStartedEvent,
|
|
||||||
)
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
|
||||||
|
|
||||||
from crewai.utilities.llm_utils import create_llm
|
from crewai.utilities.llm_utils import create_llm
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||||
@@ -82,15 +71,15 @@ class LiteAgentOutput(BaseModel):
|
|||||||
model_config = {"arbitrary_types_allowed": True}
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
raw: str = Field(description="Raw output of the agent", default="")
|
raw: str = Field(description="Raw output of the agent", default="")
|
||||||
pydantic: Optional[BaseModel] = Field(
|
pydantic: BaseModel | None = Field(
|
||||||
description="Pydantic output of the agent", default=None
|
description="Pydantic output of the agent", default=None
|
||||||
)
|
)
|
||||||
agent_role: str = Field(description="Role of the agent that produced this output")
|
agent_role: str = Field(description="Role of the agent that produced this output")
|
||||||
usage_metrics: Optional[Dict[str, Any]] = Field(
|
usage_metrics: dict[str, Any] | None = Field(
|
||||||
description="Token usage metrics for this execution", default=None
|
description="Token usage metrics for this execution", default=None
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Convert pydantic_output to a dictionary."""
|
"""Convert pydantic_output to a dictionary."""
|
||||||
if self.pydantic:
|
if self.pydantic:
|
||||||
return self.pydantic.model_dump()
|
return self.pydantic.model_dump()
|
||||||
@@ -130,10 +119,10 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
role: str = Field(description="Role of the agent")
|
role: str = Field(description="Role of the agent")
|
||||||
goal: str = Field(description="Goal of the agent")
|
goal: str = Field(description="Goal of the agent")
|
||||||
backstory: str = Field(description="Backstory of the agent")
|
backstory: str = Field(description="Backstory of the agent")
|
||||||
llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||||
default=None, description="Language model that will run the agent"
|
default=None, description="Language model that will run the agent"
|
||||||
)
|
)
|
||||||
tools: List[BaseTool] = Field(
|
tools: list[BaseTool] = Field(
|
||||||
default_factory=list, description="Tools at agent's disposal"
|
default_factory=list, description="Tools at agent's disposal"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -141,7 +130,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
max_iterations: int = Field(
|
max_iterations: int = Field(
|
||||||
default=15, description="Maximum number of iterations for tool usage"
|
default=15, description="Maximum number of iterations for tool usage"
|
||||||
)
|
)
|
||||||
max_execution_time: Optional[int] = Field(
|
max_execution_time: int | None = Field(
|
||||||
default=None, description=". Maximum execution time in seconds"
|
default=None, description=". Maximum execution time in seconds"
|
||||||
)
|
)
|
||||||
respect_context_window: bool = Field(
|
respect_context_window: bool = Field(
|
||||||
@@ -152,25 +141,25 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
default=True,
|
default=True,
|
||||||
description="Whether to use stop words to prevent the LLM from using tools",
|
description="Whether to use stop words to prevent the LLM from using tools",
|
||||||
)
|
)
|
||||||
request_within_rpm_limit: Optional[Callable[[], bool]] = Field(
|
request_within_rpm_limit: Callable[[], bool] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Callback to check if the request is within the RPM limit",
|
description="Callback to check if the request is within the RPM limit",
|
||||||
)
|
)
|
||||||
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
|
||||||
|
|
||||||
# Output and Formatting Properties
|
# Output and Formatting Properties
|
||||||
response_format: Optional[Type[BaseModel]] = Field(
|
response_format: type[BaseModel] | None = Field(
|
||||||
default=None, description="Pydantic model for structured output"
|
default=None, description="Pydantic model for structured output"
|
||||||
)
|
)
|
||||||
verbose: bool = Field(
|
verbose: bool = Field(
|
||||||
default=False, description="Whether to print execution details"
|
default=False, description="Whether to print execution details"
|
||||||
)
|
)
|
||||||
callbacks: List[Callable] = Field(
|
callbacks: list[Callable] = Field(
|
||||||
default=[], description="Callbacks to be used for the agent"
|
default=[], description="Callbacks to be used for the agent"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Guardrail Properties
|
# Guardrail Properties
|
||||||
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = (
|
guardrail: Callable[[LiteAgentOutput], tuple[bool, Any]] | str | None = (
|
||||||
Field(
|
Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Function or string description of a guardrail to validate agent output",
|
description="Function or string description of a guardrail to validate agent output",
|
||||||
@@ -181,23 +170,23 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# State and Results
|
# State and Results
|
||||||
tools_results: List[Dict[str, Any]] = Field(
|
tools_results: list[dict[str, Any]] = Field(
|
||||||
default=[], description="Results of the tools used by the agent."
|
default=[], description="Results of the tools used by the agent."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reference of Agent
|
# Reference of Agent
|
||||||
original_agent: Optional[BaseAgent] = Field(
|
original_agent: BaseAgent | None = Field(
|
||||||
default=None, description="Reference to the agent that created this LiteAgent"
|
default=None, description="Reference to the agent that created this LiteAgent"
|
||||||
)
|
)
|
||||||
# Private Attributes
|
# Private Attributes
|
||||||
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
_parsed_tools: list[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||||
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
_key: str = PrivateAttr(default_factory=lambda: str(uuid.uuid4()))
|
||||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
_messages: list[dict[str, str]] = PrivateAttr(default_factory=list)
|
||||||
_iterations: int = PrivateAttr(default=0)
|
_iterations: int = PrivateAttr(default=0)
|
||||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
_guardrail: Callable | None = PrivateAttr(default=None)
|
||||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
@@ -241,8 +230,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
@field_validator("guardrail", mode="before")
|
@field_validator("guardrail", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_guardrail_function(
|
def validate_guardrail_function(
|
||||||
cls, v: Optional[Union[Callable, str]]
|
cls, v: Callable | str | None
|
||||||
) -> Optional[Union[Callable, str]]:
|
) -> Callable | str | None:
|
||||||
"""Validate that the guardrail function has the correct signature.
|
"""Validate that the guardrail function has the correct signature.
|
||||||
|
|
||||||
If v is a callable, validate that it has the correct signature.
|
If v is a callable, validate that it has the correct signature.
|
||||||
@@ -267,7 +256,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
|
|
||||||
# Check return annotation if present
|
# Check return annotation if present
|
||||||
if sig.return_annotation is not sig.empty:
|
if sig.return_annotation is not sig.empty:
|
||||||
if sig.return_annotation == Tuple[bool, Any]:
|
if sig.return_annotation == tuple[bool, Any]:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
origin = get_origin(sig.return_annotation)
|
origin = get_origin(sig.return_annotation)
|
||||||
@@ -290,7 +279,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
"""Return the original role for compatibility with tool interfaces."""
|
"""Return the original role for compatibility with tool interfaces."""
|
||||||
return self.role
|
return self.role
|
||||||
|
|
||||||
def kickoff(self, messages: Union[str, List[Dict[str, str]]]) -> LiteAgentOutput:
|
def kickoff(self, messages: str | list[dict[str, str]]) -> LiteAgentOutput:
|
||||||
"""
|
"""
|
||||||
Execute the agent with the given messages.
|
Execute the agent with the given messages.
|
||||||
|
|
||||||
@@ -338,7 +327,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
|
def _execute_core(self, agent_info: dict[str, Any]) -> LiteAgentOutput:
|
||||||
# Emit event for agent execution start
|
# Emit event for agent execution start
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -351,16 +340,21 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
|
|
||||||
# Execute the agent using invoke loop
|
# Execute the agent using invoke loop
|
||||||
agent_finish = self._invoke_loop()
|
agent_finish = self._invoke_loop()
|
||||||
formatted_result: Optional[BaseModel] = None
|
formatted_result: BaseModel | None = None
|
||||||
if self.response_format:
|
if self.response_format:
|
||||||
try:
|
try:
|
||||||
# Cast to BaseModel to ensure type safety
|
converted_result = convert_to_model(
|
||||||
result = self.response_format.model_validate_json(agent_finish.output)
|
result=agent_finish.output,
|
||||||
if isinstance(result, BaseModel):
|
output_pydantic=self.response_format,
|
||||||
formatted_result = result
|
output_json=None,
|
||||||
|
agent=self,
|
||||||
|
converter_cls=None,
|
||||||
|
)
|
||||||
|
if isinstance(converted_result, BaseModel):
|
||||||
|
formatted_result = converted_result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._printer.print(
|
self._printer.print(
|
||||||
content=f"Failed to parse output into response format: {str(e)}",
|
content=f"Failed to parse output into response format: {e!s}",
|
||||||
color="yellow",
|
color="yellow",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -428,7 +422,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
async def kickoff_async(
|
async def kickoff_async(
|
||||||
self, messages: Union[str, List[Dict[str, str]]]
|
self, messages: str | list[dict[str, str]]
|
||||||
) -> LiteAgentOutput:
|
) -> LiteAgentOutput:
|
||||||
"""
|
"""
|
||||||
Execute the agent asynchronously with the given messages.
|
Execute the agent asynchronously with the given messages.
|
||||||
@@ -475,8 +469,8 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
return base_prompt
|
return base_prompt
|
||||||
|
|
||||||
def _format_messages(
|
def _format_messages(
|
||||||
self, messages: Union[str, List[Dict[str, str]]]
|
self, messages: str | list[dict[str, str]]
|
||||||
) -> List[Dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
"""Format messages for the LLM."""
|
"""Format messages for the LLM."""
|
||||||
if isinstance(messages, str):
|
if isinstance(messages, str):
|
||||||
messages = [{"role": "user", "content": messages}]
|
messages = [{"role": "user", "content": messages}]
|
||||||
@@ -571,18 +565,18 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
i18n=self.i18n,
|
i18n=self.i18n,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
else:
|
handle_unknown_error(self._printer, e)
|
||||||
handle_unknown_error(self._printer, e)
|
raise e
|
||||||
raise e
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self._iterations += 1
|
self._iterations += 1
|
||||||
|
|
||||||
assert isinstance(formatted_answer, AgentFinish)
|
if not isinstance(formatted_answer, AgentFinish):
|
||||||
|
raise ValueError(f"Expected AgentFinish, got {type(formatted_answer)}")
|
||||||
self._show_logs(formatted_answer)
|
self._show_logs(formatted_answer)
|
||||||
return formatted_answer
|
return formatted_answer
|
||||||
|
|
||||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
def _show_logs(self, formatted_answer: AgentAction | AgentFinish):
|
||||||
"""Show logs for the agent's execution."""
|
"""Show logs for the agent's execution."""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -596,3 +590,13 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
def _append_message(self, text: str, role: str = "assistant") -> None:
|
def _append_message(self, text: str, role: str = "assistant") -> None:
|
||||||
"""Append a message to the message list with the given role."""
|
"""Append a message to the message list with the given role."""
|
||||||
self._messages.append(format_message_for_llm(text, role=role))
|
self._messages.append(format_message_for_llm(text, role=role))
|
||||||
|
|
||||||
|
def get_output_converter(self, llm, model, instructions):
|
||||||
|
"""Get the converter class for the agent to create json/pydantic outputs."""
|
||||||
|
from crewai.utilities.converter import Converter
|
||||||
|
return Converter(
|
||||||
|
text="",
|
||||||
|
llm=llm,
|
||||||
|
model=model,
|
||||||
|
instructions=instructions,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,19 +1,18 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai import LLM, Agent
|
from crewai import LLM, Agent
|
||||||
from crewai.flow import Flow, start
|
|
||||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
|
||||||
from crewai.tools import BaseTool
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.agent_events import LiteAgentExecutionStartedEvent
|
from crewai.events.types.agent_events import LiteAgentExecutionStartedEvent
|
||||||
from crewai.events.types.tool_usage_events import ToolUsageStartedEvent
|
from crewai.events.types.tool_usage_events import ToolUsageStartedEvent
|
||||||
|
from crewai.flow import Flow, start
|
||||||
|
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from unittest.mock import patch
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
# A simple test tool
|
# A simple test tool
|
||||||
@@ -37,10 +36,9 @@ class WebSearchTool(BaseTool):
|
|||||||
# This is a mock implementation
|
# This is a mock implementation
|
||||||
if "tokyo" in query.lower():
|
if "tokyo" in query.lower():
|
||||||
return "Tokyo's population in 2023 was approximately 21 million people in the city proper, and 37 million in the greater metropolitan area."
|
return "Tokyo's population in 2023 was approximately 21 million people in the city proper, and 37 million in the greater metropolitan area."
|
||||||
elif "climate change" in query.lower() and "coral" in query.lower():
|
if "climate change" in query.lower() and "coral" in query.lower():
|
||||||
return "Climate change severely impacts coral reefs through: 1) Ocean warming causing coral bleaching, 2) Ocean acidification reducing calcification, 3) Sea level rise affecting light availability, 4) Increased storm frequency damaging reef structures. Sources: NOAA Coral Reef Conservation Program, Global Coral Reef Alliance."
|
return "Climate change severely impacts coral reefs through: 1) Ocean warming causing coral bleaching, 2) Ocean acidification reducing calcification, 3) Sea level rise affecting light availability, 4) Increased storm frequency damaging reef structures. Sources: NOAA Coral Reef Conservation Program, Global Coral Reef Alliance."
|
||||||
else:
|
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
|
||||||
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
|
|
||||||
|
|
||||||
|
|
||||||
# Define Mock Calculator Tool
|
# Define Mock Calculator Tool
|
||||||
@@ -52,11 +50,12 @@ class CalculatorTool(BaseTool):
|
|||||||
|
|
||||||
def _run(self, expression: str) -> str:
|
def _run(self, expression: str) -> str:
|
||||||
"""Calculate the result of a mathematical expression."""
|
"""Calculate the result of a mathematical expression."""
|
||||||
|
import ast
|
||||||
try:
|
try:
|
||||||
result = eval(expression, {"__builtins__": {}})
|
result = ast.literal_eval(expression)
|
||||||
return f"The result of {expression} is {result}"
|
return f"The result of {expression} is {result}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error calculating {expression}: {str(e)}"
|
return f"Error calculating {expression}: {e!s}"
|
||||||
|
|
||||||
|
|
||||||
# Define a custom response format using Pydantic
|
# Define a custom response format using Pydantic
|
||||||
@@ -520,6 +519,53 @@ def test_lite_agent_with_custom_llm_and_guardrails():
|
|||||||
assert result2.raw == "Modified by guardrail"
|
assert result2.raw == "Modified by guardrail"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lite_agent_structured_output_with_malformed_json():
|
||||||
|
"""Test that LiteAgent can handle malformed JSON wrapped in markdown blocks."""
|
||||||
|
|
||||||
|
class FounderNames(BaseModel):
|
||||||
|
names: list[str] = Field(description="List of founder names")
|
||||||
|
|
||||||
|
class MockLLMWithMalformedJSON(BaseLLM):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(model="mock-model")
|
||||||
|
|
||||||
|
def call(self, messages, **kwargs):
|
||||||
|
return '''Thought: I need to extract the founder names
|
||||||
|
Final Answer: ```json
|
||||||
|
{
|
||||||
|
"names": ["John Smith", "Jane Doe"]
|
||||||
|
}
|
||||||
|
```'''
|
||||||
|
|
||||||
|
def supports_function_calling(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def supports_stop_words(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_context_window_size(self):
|
||||||
|
return 4096
|
||||||
|
|
||||||
|
mock_llm = MockLLMWithMalformedJSON()
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Data Extraction Specialist",
|
||||||
|
goal="Extract founder names from text",
|
||||||
|
backstory="You extract and structure information accurately.",
|
||||||
|
llm=mock_llm,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = agent.kickoff(
|
||||||
|
messages="Extract founder names from: 'The company was founded by John Smith and Jane Doe.'",
|
||||||
|
response_format=FounderNames
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.pydantic is not None, "Should successfully parse malformed JSON"
|
||||||
|
assert isinstance(result.pydantic, FounderNames), "Should return correct Pydantic model"
|
||||||
|
assert result.pydantic.names == ["John Smith", "Jane Doe"], "Should extract correct founder names"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_lite_agent_with_invalid_llm():
|
def test_lite_agent_with_invalid_llm():
|
||||||
"""Test that LiteAgent raises proper error when create_llm returns None."""
|
"""Test that LiteAgent raises proper error when create_llm returns None."""
|
||||||
|
|||||||
Reference in New Issue
Block a user