mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
5 Commits
devin/1756
...
devin/1750
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76fea442d4 | ||
|
|
b1ae914cd3 | ||
|
|
14629bb87a | ||
|
|
a1ebdb125b | ||
|
|
39ea952acd |
166
examples/llm_generations_example.py
Normal file
166
examples/llm_generations_example.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Example demonstrating the new LLM generations and logprobs functionality.
|
||||
"""
|
||||
|
||||
from crewai import Agent, Task, LLM
|
||||
from crewai.utilities.xml_parser import extract_xml_content
|
||||
|
||||
|
||||
def example_multiple_generations():
|
||||
"""Example of using multiple generations with an agent."""
|
||||
|
||||
llm = LLM(
|
||||
model="gpt-3.5-turbo",
|
||||
n=3, # Request 3 generations
|
||||
temperature=0.8, # Higher temperature for variety
|
||||
return_full_completion=True
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
role="Creative Writer",
|
||||
goal="Write engaging content",
|
||||
backstory="You are a creative writer who generates multiple ideas",
|
||||
llm=llm,
|
||||
return_completion_metadata=True
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Write a short story opening about a mysterious door",
|
||||
agent=agent,
|
||||
expected_output="A compelling story opening"
|
||||
)
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
print("Primary result:")
|
||||
print(result)
|
||||
print("\n" + "="*50 + "\n")
|
||||
|
||||
if hasattr(task, 'output') and task.output.completion_metadata:
|
||||
generations = task.output.get_generations()
|
||||
if generations:
|
||||
print(f"Generated {len(generations)} alternatives:")
|
||||
for i, generation in enumerate(generations, 1):
|
||||
print(f"\nGeneration {i}:")
|
||||
print(generation)
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def example_xml_extraction():
|
||||
"""Example of extracting structured content from agent output."""
|
||||
|
||||
agent = Agent(
|
||||
role="Problem Solver",
|
||||
goal="Solve problems systematically",
|
||||
backstory="You think step by step and show your reasoning",
|
||||
llm=LLM(model="gpt-3.5-turbo")
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="""
|
||||
Solve this problem: How can we reduce energy consumption in an office building?
|
||||
|
||||
Please structure your response with:
|
||||
- <thinking> tags for your internal reasoning
|
||||
- <analysis> tags for your analysis of the problem
|
||||
- <solution> tags for your proposed solution
|
||||
""",
|
||||
agent=agent,
|
||||
expected_output="A structured solution with reasoning"
|
||||
)
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
print("Full agent output:")
|
||||
print(result)
|
||||
print("\n" + "="*50 + "\n")
|
||||
|
||||
thinking = extract_xml_content(result, "thinking")
|
||||
analysis = extract_xml_content(result, "analysis")
|
||||
solution = extract_xml_content(result, "solution")
|
||||
|
||||
if thinking:
|
||||
print("Agent's thinking process:")
|
||||
print(thinking)
|
||||
print("\n" + "-"*30 + "\n")
|
||||
|
||||
if analysis:
|
||||
print("Problem analysis:")
|
||||
print(analysis)
|
||||
print("\n" + "-"*30 + "\n")
|
||||
|
||||
if solution:
|
||||
print("Proposed solution:")
|
||||
print(solution)
|
||||
|
||||
|
||||
def example_logprobs_analysis():
|
||||
"""Example of accessing log probabilities for analysis."""
|
||||
|
||||
llm = LLM(
|
||||
model="gpt-3.5-turbo",
|
||||
logprobs=5, # Request top 5 log probabilities
|
||||
top_logprobs=3, # Show top 3 alternatives
|
||||
return_full_completion=True
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
role="Decision Analyst",
|
||||
goal="Make confident decisions",
|
||||
backstory="You analyze confidence levels in your responses",
|
||||
llm=llm,
|
||||
return_completion_metadata=True
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Should we invest in renewable energy? Give a yes/no answer with confidence.",
|
||||
agent=agent,
|
||||
expected_output="A clear yes/no decision"
|
||||
)
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
print("Decision:")
|
||||
print(result)
|
||||
print("\n" + "="*50 + "\n")
|
||||
|
||||
if hasattr(task, 'output') and task.output.completion_metadata:
|
||||
logprobs = task.output.get_logprobs()
|
||||
usage = task.output.get_usage_metrics()
|
||||
|
||||
if logprobs:
|
||||
print("Confidence analysis (log probabilities):")
|
||||
print(f"Available logprobs data: {len(logprobs)} choices")
|
||||
|
||||
if usage:
|
||||
print("\nToken usage:")
|
||||
print(f"Prompt tokens: {usage.get('prompt_tokens', 'N/A')}")
|
||||
print(f"Completion tokens: {usage.get('completion_tokens', 'N/A')}")
|
||||
print(f"Total tokens: {usage.get('total_tokens', 'N/A')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=== CrewAI LLM Generations and XML Extraction Examples ===\n")
|
||||
|
||||
print("1. Multiple Generations Example:")
|
||||
print("-" * 40)
|
||||
try:
|
||||
example_multiple_generations()
|
||||
except Exception as e:
|
||||
print(f"Example requires actual LLM API access: {e}")
|
||||
|
||||
print("\n\n2. XML Content Extraction Example:")
|
||||
print("-" * 40)
|
||||
try:
|
||||
example_xml_extraction()
|
||||
except Exception as e:
|
||||
print(f"Example requires actual LLM API access: {e}")
|
||||
|
||||
print("\n\n3. Log Probabilities Analysis Example:")
|
||||
print("-" * 40)
|
||||
try:
|
||||
example_logprobs_analysis()
|
||||
except Exception as e:
|
||||
print(f"Example requires actual LLM API access: {e}")
|
||||
|
||||
print("\n=== Examples completed ===")
|
||||
@@ -88,6 +88,22 @@ class Agent(BaseAgent):
|
||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
llm_n: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Number of generations to request from the LLM.",
|
||||
)
|
||||
llm_logprobs: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Number of log probabilities to return from the LLM.",
|
||||
)
|
||||
llm_top_logprobs: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Number of top log probabilities to return from the LLM.",
|
||||
)
|
||||
return_completion_metadata: bool = Field(
|
||||
default=False,
|
||||
description="Whether to return full completion metadata including generations and logprobs.",
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
@@ -179,6 +195,15 @@ class Agent(BaseAgent):
|
||||
):
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
|
||||
if hasattr(self.llm, 'n') and self.llm_n is not None:
|
||||
self.llm.n = self.llm_n
|
||||
if hasattr(self.llm, 'logprobs') and self.llm_logprobs is not None:
|
||||
self.llm.logprobs = self.llm_logprobs
|
||||
if hasattr(self.llm, 'top_logprobs') and self.llm_top_logprobs is not None:
|
||||
self.llm.top_logprobs = self.llm_top_logprobs
|
||||
if hasattr(self.llm, 'return_full_completion'):
|
||||
self.llm.return_full_completion = self.return_completion_metadata
|
||||
|
||||
if not self.agent_executor:
|
||||
self._setup_agent_executor()
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ def handle_user_input(
|
||||
available_functions=available_functions,
|
||||
)
|
||||
|
||||
messages.append({"role": "assistant", "content": final_response})
|
||||
messages.append({"role": "assistant", "content": str(final_response)})
|
||||
click.secho(f"\nAssistant: {final_response}\n", fg="green")
|
||||
|
||||
|
||||
|
||||
@@ -92,6 +92,9 @@ class LiteAgentOutput(BaseModel):
|
||||
usage_metrics: Optional[Dict[str, Any]] = Field(
|
||||
description="Token usage metrics for this execution", default=None
|
||||
)
|
||||
completion_metadata: Optional[Dict[str, Any]] = Field(
|
||||
description="Full completion metadata including generations and logprobs", default=None
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert pydantic_output to a dictionary."""
|
||||
@@ -99,6 +102,40 @@ class LiteAgentOutput(BaseModel):
|
||||
return self.pydantic.model_dump()
|
||||
return {}
|
||||
|
||||
def get_generations(self) -> Optional[List[str]]:
|
||||
"""Get all generations from completion metadata."""
|
||||
if not self.completion_metadata or "choices" not in self.completion_metadata:
|
||||
return None
|
||||
|
||||
generations = []
|
||||
for choice in self.completion_metadata["choices"]:
|
||||
if hasattr(choice, "message") and hasattr(choice.message, "content"):
|
||||
generations.append(choice.message.content or "")
|
||||
elif isinstance(choice, dict) and "message" in choice:
|
||||
generations.append(choice["message"].get("content", ""))
|
||||
|
||||
return generations if generations else None
|
||||
|
||||
def get_logprobs(self) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Get log probabilities from completion metadata."""
|
||||
if not self.completion_metadata or "choices" not in self.completion_metadata:
|
||||
return None
|
||||
|
||||
logprobs_list = []
|
||||
for choice in self.completion_metadata["choices"]:
|
||||
if hasattr(choice, "logprobs") and choice.logprobs:
|
||||
logprobs_list.append(choice.logprobs)
|
||||
elif isinstance(choice, dict) and "logprobs" in choice:
|
||||
logprobs_list.append(choice["logprobs"])
|
||||
|
||||
return logprobs_list if logprobs_list else None
|
||||
|
||||
def get_usage_metrics_from_completion(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get token usage metrics from completion metadata."""
|
||||
if not self.completion_metadata:
|
||||
return None
|
||||
return self.completion_metadata.get("usage")
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the output."""
|
||||
if self.pydantic:
|
||||
|
||||
@@ -311,6 +311,7 @@ class LLM(BaseLLM):
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
stream: bool = False,
|
||||
return_full_completion: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
@@ -337,6 +338,7 @@ class LLM(BaseLLM):
|
||||
self.additional_params = kwargs
|
||||
self.is_anthropic = self._is_anthropic_model(model)
|
||||
self.stream = stream
|
||||
self.return_full_completion = return_full_completion
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
@@ -419,16 +421,18 @@ class LLM(BaseLLM):
|
||||
params: Dict[str, Any],
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
return_full_completion: bool = False,
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""Handle a streaming response from the LLM.
|
||||
|
||||
Args:
|
||||
params: Parameters for the completion call
|
||||
callbacks: Optional list of callback functions
|
||||
available_functions: Dict of available functions
|
||||
return_full_completion: Whether to return full completion object
|
||||
|
||||
Returns:
|
||||
str: The complete response text
|
||||
Union[str, Dict[str, Any]]: The complete response text or full completion object
|
||||
|
||||
Raises:
|
||||
Exception: If no content is received from the streaming response
|
||||
@@ -626,11 +630,46 @@ class LLM(BaseLLM):
|
||||
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
|
||||
# Emit completion event and return response
|
||||
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
||||
|
||||
if return_full_completion:
|
||||
accumulated_choices = []
|
||||
if last_chunk and hasattr(last_chunk, "choices"):
|
||||
for choice in last_chunk.choices:
|
||||
accumulated_choices.append({
|
||||
"message": {"content": full_response},
|
||||
"finish_reason": getattr(choice, "finish_reason", None),
|
||||
"index": getattr(choice, "index", 0),
|
||||
})
|
||||
else:
|
||||
accumulated_choices = [{"message": {"content": full_response}, "finish_reason": "stop", "index": 0}]
|
||||
|
||||
return {
|
||||
"content": full_response,
|
||||
"choices": accumulated_choices,
|
||||
"usage": usage_info,
|
||||
"model": params.get("model"),
|
||||
"created": getattr(last_chunk, "created", None) if last_chunk else None,
|
||||
"id": getattr(last_chunk, "id", None) if last_chunk else None,
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": getattr(last_chunk, "system_fingerprint", None) if last_chunk else None,
|
||||
}
|
||||
|
||||
return full_response
|
||||
|
||||
# --- 9) Handle tool calls if present
|
||||
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
||||
if tool_result is not None:
|
||||
if return_full_completion:
|
||||
return {
|
||||
"content": tool_result,
|
||||
"choices": [{"message": {"content": tool_result}}],
|
||||
"usage": usage_info,
|
||||
"model": params.get("model"),
|
||||
"created": getattr(last_chunk, "created", None) if last_chunk else None,
|
||||
"id": getattr(last_chunk, "id", None) if last_chunk else None,
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": getattr(last_chunk, "system_fingerprint", None) if last_chunk else None,
|
||||
}
|
||||
return tool_result
|
||||
|
||||
# --- 10) Log token usage if available in streaming mode
|
||||
@@ -638,6 +677,30 @@ class LLM(BaseLLM):
|
||||
|
||||
# --- 11) Emit completion event and return response
|
||||
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
||||
|
||||
if return_full_completion:
|
||||
accumulated_choices = []
|
||||
if last_chunk and hasattr(last_chunk, "choices"):
|
||||
for choice in last_chunk.choices:
|
||||
accumulated_choices.append({
|
||||
"message": {"content": full_response},
|
||||
"finish_reason": getattr(choice, "finish_reason", None),
|
||||
"index": getattr(choice, "index", 0),
|
||||
})
|
||||
else:
|
||||
accumulated_choices = [{"message": {"content": full_response}, "finish_reason": "stop", "index": 0}]
|
||||
|
||||
return {
|
||||
"content": full_response,
|
||||
"choices": accumulated_choices,
|
||||
"usage": usage_info,
|
||||
"model": params.get("model"),
|
||||
"created": getattr(last_chunk, "created", None) if last_chunk else None,
|
||||
"id": getattr(last_chunk, "id", None) if last_chunk else None,
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": getattr(last_chunk, "system_fingerprint", None) if last_chunk else None,
|
||||
}
|
||||
|
||||
return full_response
|
||||
|
||||
except ContextWindowExceededError as e:
|
||||
@@ -748,16 +811,18 @@ class LLM(BaseLLM):
|
||||
params: Dict[str, Any],
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
return_full_completion: bool = False,
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""Handle a non-streaming response from the LLM.
|
||||
|
||||
Args:
|
||||
params: Parameters for the completion call
|
||||
callbacks: Optional list of callback functions
|
||||
available_functions: Dict of available functions
|
||||
return_full_completion: Whether to return full completion object
|
||||
|
||||
Returns:
|
||||
str: The response text
|
||||
Union[str, Dict[str, Any]]: The response text or full completion object
|
||||
"""
|
||||
# --- 1) Make the completion call
|
||||
try:
|
||||
@@ -793,18 +858,51 @@ class LLM(BaseLLM):
|
||||
# --- 4) Check for tool calls
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
|
||||
# --- 5) If no tool calls or no available functions, return the text response directly
|
||||
# --- 5) If no tool calls or no available functions, return the response
|
||||
if not tool_calls or not available_functions:
|
||||
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
|
||||
if return_full_completion:
|
||||
return {
|
||||
"content": text_response,
|
||||
"choices": response.choices,
|
||||
"usage": getattr(response, "usage", None),
|
||||
"model": response.model,
|
||||
"created": getattr(response, "created", None),
|
||||
"id": getattr(response, "id", None),
|
||||
"object": getattr(response, "object", "chat.completion"),
|
||||
"system_fingerprint": getattr(response, "system_fingerprint", None),
|
||||
}
|
||||
return text_response
|
||||
|
||||
# --- 6) Handle tool calls if present
|
||||
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
||||
if tool_result is not None:
|
||||
if return_full_completion:
|
||||
return {
|
||||
"content": tool_result,
|
||||
"choices": response.choices,
|
||||
"usage": getattr(response, "usage", None),
|
||||
"model": response.model,
|
||||
"created": getattr(response, "created", None),
|
||||
"id": getattr(response, "id", None),
|
||||
"object": getattr(response, "object", "chat.completion"),
|
||||
"system_fingerprint": getattr(response, "system_fingerprint", None),
|
||||
}
|
||||
return tool_result
|
||||
|
||||
# --- 7) If tool call handling didn't return a result, emit completion event and return text response
|
||||
# --- 7) If tool call handling didn't return a result, emit completion event and return response
|
||||
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
|
||||
if return_full_completion:
|
||||
return {
|
||||
"content": text_response,
|
||||
"choices": response.choices,
|
||||
"usage": getattr(response, "usage", None),
|
||||
"model": response.model,
|
||||
"created": getattr(response, "created", None),
|
||||
"id": getattr(response, "id", None),
|
||||
"object": getattr(response, "object", "chat.completion"),
|
||||
"system_fingerprint": getattr(response, "system_fingerprint", None),
|
||||
}
|
||||
return text_response
|
||||
|
||||
def _handle_tool_call(
|
||||
@@ -889,7 +987,8 @@ class LLM(BaseLLM):
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
return_full_completion: Optional[bool] = None,
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""High-level LLM call method.
|
||||
|
||||
Args:
|
||||
@@ -903,10 +1002,11 @@ class LLM(BaseLLM):
|
||||
during and after the LLM call.
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
return_full_completion: Optional override for returning full completion object
|
||||
|
||||
Returns:
|
||||
Union[str, Any]: Either a text response from the LLM (str) or
|
||||
the result of a tool function call (Any).
|
||||
Union[str, Dict[str, Any]]: Either a text response from the LLM (str) or
|
||||
the full completion object with generations and metadata.
|
||||
|
||||
Raises:
|
||||
TypeError: If messages format is invalid
|
||||
@@ -944,17 +1044,20 @@ class LLM(BaseLLM):
|
||||
self.set_callbacks(callbacks)
|
||||
|
||||
try:
|
||||
# --- 6) Prepare parameters for the completion call
|
||||
# --- 6) Determine if we should return full completion
|
||||
should_return_full = return_full_completion if return_full_completion is not None else self.return_full_completion
|
||||
|
||||
# --- 7) Prepare parameters for the completion call
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
|
||||
# --- 7) Make the completion call and handle response
|
||||
# --- 8) Make the completion call and handle response
|
||||
if self.stream:
|
||||
return self._handle_streaming_response(
|
||||
params, callbacks, available_functions
|
||||
params, callbacks, available_functions, should_return_full
|
||||
)
|
||||
else:
|
||||
return self._handle_non_streaming_response(
|
||||
params, callbacks, available_functions
|
||||
params, callbacks, available_functions, should_return_full
|
||||
)
|
||||
|
||||
except LLMContextLengthExceededException:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@@ -26,6 +26,9 @@ class TaskOutput(BaseModel):
|
||||
output_format: OutputFormat = Field(
|
||||
description="Output format of the task", default=OutputFormat.RAW
|
||||
)
|
||||
completion_metadata: Optional[Dict[str, Any]] = Field(
|
||||
description="Full completion metadata including generations and logprobs", default=None
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_summary(self):
|
||||
@@ -56,6 +59,40 @@ class TaskOutput(BaseModel):
|
||||
output_dict.update(self.pydantic.model_dump())
|
||||
return output_dict
|
||||
|
||||
def get_generations(self) -> Optional[List[str]]:
|
||||
"""Get all generations from completion metadata."""
|
||||
if not self.completion_metadata or "choices" not in self.completion_metadata:
|
||||
return None
|
||||
|
||||
generations = []
|
||||
for choice in self.completion_metadata["choices"]:
|
||||
if hasattr(choice, "message") and hasattr(choice.message, "content"):
|
||||
generations.append(choice.message.content or "")
|
||||
elif isinstance(choice, dict) and "message" in choice:
|
||||
generations.append(choice["message"].get("content", ""))
|
||||
|
||||
return generations if generations else None
|
||||
|
||||
def get_logprobs(self) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Get log probabilities from completion metadata."""
|
||||
if not self.completion_metadata or "choices" not in self.completion_metadata:
|
||||
return None
|
||||
|
||||
logprobs_list = []
|
||||
for choice in self.completion_metadata["choices"]:
|
||||
if hasattr(choice, "logprobs") and choice.logprobs:
|
||||
logprobs_list.append(choice.logprobs)
|
||||
elif isinstance(choice, dict) and "logprobs" in choice:
|
||||
logprobs_list.append(choice["logprobs"])
|
||||
|
||||
return logprobs_list if logprobs_list else None
|
||||
|
||||
def get_usage_metrics(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get token usage metrics from completion metadata."""
|
||||
if not self.completion_metadata:
|
||||
return None
|
||||
return self.completion_metadata.get("usage")
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.pydantic:
|
||||
return str(self.pydantic)
|
||||
|
||||
@@ -109,7 +109,11 @@ def handle_max_iterations_exceeded(
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
|
||||
formatted_answer = format_answer(answer)
|
||||
text_answer = answer
|
||||
if isinstance(answer, dict) and "content" in answer:
|
||||
text_answer = answer["content"]
|
||||
|
||||
formatted_answer = format_answer(str(text_answer))
|
||||
# Return the formatted answer, regardless of its type
|
||||
return formatted_answer
|
||||
|
||||
@@ -145,42 +149,64 @@ def get_llm_response(
|
||||
messages: List[Dict[str, str]],
|
||||
callbacks: List[Any],
|
||||
printer: Printer,
|
||||
) -> str:
|
||||
return_full_completion: bool = False,
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""Call the LLM and return the response, handling any invalid responses."""
|
||||
try:
|
||||
answer = llm.call(
|
||||
messages,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
from crewai.llm import LLM
|
||||
if isinstance(llm, LLM) and return_full_completion:
|
||||
answer = llm.call(
|
||||
messages,
|
||||
callbacks=callbacks,
|
||||
return_full_completion=return_full_completion,
|
||||
)
|
||||
else:
|
||||
answer = llm.call(
|
||||
messages,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
except Exception as e:
|
||||
printer.print(
|
||||
content=f"Error during LLM call: {e}",
|
||||
color="red",
|
||||
)
|
||||
raise e
|
||||
if not answer:
|
||||
printer.print(
|
||||
content="Received None or empty response from LLM call.",
|
||||
color="red",
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
|
||||
return answer
|
||||
|
||||
if return_full_completion:
|
||||
if not answer or (isinstance(answer, dict) and not answer.get("content")):
|
||||
printer.print(
|
||||
content="Received None or empty response from LLM call.",
|
||||
color="red",
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
return answer
|
||||
else:
|
||||
if not answer:
|
||||
printer.print(
|
||||
content="Received None or empty response from LLM call.",
|
||||
color="red",
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
return answer
|
||||
|
||||
|
||||
def process_llm_response(
|
||||
answer: str, use_stop_words: bool
|
||||
answer: Union[str, Dict[str, Any]], use_stop_words: bool
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
|
||||
text_answer = answer
|
||||
if isinstance(answer, dict):
|
||||
text_answer = answer.get("content", "")
|
||||
|
||||
if not use_stop_words:
|
||||
try:
|
||||
# Preliminary parsing to check for errors.
|
||||
format_answer(answer)
|
||||
format_answer(str(text_answer))
|
||||
except OutputParserException as e:
|
||||
if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
|
||||
answer = answer.split("Observation:")[0].strip()
|
||||
text_answer = str(text_answer).split("Observation:")[0].strip()
|
||||
|
||||
return format_answer(answer)
|
||||
return format_answer(str(text_answer))
|
||||
|
||||
|
||||
def handle_agent_action_core(
|
||||
|
||||
@@ -260,14 +260,22 @@ class AgentReasoning:
|
||||
available_functions={"create_reasoning_plan": _create_reasoning_plan},
|
||||
)
|
||||
|
||||
self.logger.debug(f"Function calling response: {response[:100]}...")
|
||||
|
||||
try:
|
||||
result = json.loads(response)
|
||||
if "plan" in result and "ready" in result:
|
||||
return result["plan"], result["ready"]
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
if isinstance(response, dict):
|
||||
response_str = str(response)
|
||||
self.logger.debug(f"Function calling response: {response_str[:100]}...")
|
||||
|
||||
if "plan" in response and "ready" in response:
|
||||
return response["plan"], response["ready"]
|
||||
else:
|
||||
response_str = str(response)
|
||||
self.logger.debug(f"Function calling response: {response_str[:100]}...")
|
||||
|
||||
try:
|
||||
result = json.loads(response_str)
|
||||
if "plan" in result and "ready" in result:
|
||||
return result["plan"], result["ready"]
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
response_str = str(response)
|
||||
return response_str, "READY: I am ready to execute the task." in response_str
|
||||
|
||||
133
src/crewai/utilities/xml_parser.py
Normal file
133
src/crewai/utilities/xml_parser.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
|
||||
def extract_xml_content(text: str, tag: str) -> Optional[str]:
|
||||
"""
|
||||
Extract content from a specific XML tag.
|
||||
|
||||
Args:
|
||||
text: The text to search in
|
||||
tag: The XML tag name to extract (without angle brackets)
|
||||
|
||||
Returns:
|
||||
The content inside the first occurrence of the tag, or None if not found
|
||||
"""
|
||||
pattern = rf'<{re.escape(tag)}>(.*?)</{re.escape(tag)}>'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
|
||||
def extract_all_xml_content(text: str, tag: str) -> List[str]:
|
||||
"""
|
||||
Extract content from all occurrences of a specific XML tag.
|
||||
|
||||
Args:
|
||||
text: The text to search in
|
||||
tag: The XML tag name to extract (without angle brackets)
|
||||
|
||||
Returns:
|
||||
List of content strings from all occurrences of the tag
|
||||
"""
|
||||
pattern = rf'<{re.escape(tag)}>(.*?)</{re.escape(tag)}>'
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
return [match.strip() for match in matches]
|
||||
|
||||
|
||||
def extract_multiple_xml_tags(text: str, tags: List[str]) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Extract content from multiple XML tags.
|
||||
|
||||
Args:
|
||||
text: The text to search in
|
||||
tags: List of XML tag names to extract (without angle brackets)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tag names to their content (or None if not found)
|
||||
"""
|
||||
result = {}
|
||||
for tag in tags:
|
||||
result[tag] = extract_xml_content(text, tag)
|
||||
return result
|
||||
|
||||
|
||||
def extract_multiple_xml_tags_all(text: str, tags: List[str]) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Extract content from all occurrences of multiple XML tags.
|
||||
|
||||
Args:
|
||||
text: The text to search in
|
||||
tags: List of XML tag names to extract (without angle brackets)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tag names to lists of their content
|
||||
"""
|
||||
result = {}
|
||||
for tag in tags:
|
||||
result[tag] = extract_all_xml_content(text, tag)
|
||||
return result
|
||||
|
||||
|
||||
def extract_xml_with_attributes(text: str, tag: str) -> List[Dict[str, Union[str, Dict[str, str]]]]:
|
||||
"""
|
||||
Extract XML tags with their attributes and content.
|
||||
|
||||
Args:
|
||||
text: The text to search in
|
||||
tag: The XML tag name to extract (without angle brackets)
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing 'content' and 'attributes' for each occurrence
|
||||
"""
|
||||
pattern = rf'<{re.escape(tag)}([^>]*)>(.*?)</{re.escape(tag)}>'
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
||||
result = []
|
||||
for attrs_str, content in matches:
|
||||
attributes = {}
|
||||
if attrs_str.strip():
|
||||
attr_pattern = r'(\w+)=["\']([^"\']*)["\']'
|
||||
attributes = dict(re.findall(attr_pattern, attrs_str))
|
||||
|
||||
result.append({
|
||||
'content': content.strip(),
|
||||
'attributes': attributes
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def remove_xml_tags(text: str, tags: List[str]) -> str:
|
||||
"""
|
||||
Remove specific XML tags and their content from text.
|
||||
|
||||
Args:
|
||||
text: The text to process
|
||||
tags: List of XML tag names to remove (without angle brackets)
|
||||
|
||||
Returns:
|
||||
Text with the specified XML tags and their content removed
|
||||
"""
|
||||
result = text
|
||||
for tag in tags:
|
||||
pattern = rf'<{re.escape(tag)}[^>]*>.*?</{re.escape(tag)}>'
|
||||
result = re.sub(pattern, '', result, flags=re.DOTALL)
|
||||
return result.strip()
|
||||
|
||||
|
||||
def strip_xml_tags_keep_content(text: str, tags: List[str]) -> str:
|
||||
"""
|
||||
Remove specific XML tags but keep their content.
|
||||
|
||||
Args:
|
||||
text: The text to process
|
||||
tags: List of XML tag names to strip (without angle brackets)
|
||||
|
||||
Returns:
|
||||
Text with the specified XML tags removed but content preserved
|
||||
"""
|
||||
result = text
|
||||
for tag in tags:
|
||||
pattern = rf'<{re.escape(tag)}[^>]*>(.*?)</{re.escape(tag)}>'
|
||||
result = re.sub(pattern, r'\1', result, flags=re.DOTALL)
|
||||
return result.strip()
|
||||
@@ -162,6 +162,7 @@ def test_task_callback_returns_task_output():
|
||||
"name": None,
|
||||
"expected_output": "Bullet point list of 5 interesting ideas.",
|
||||
"output_format": OutputFormat.RAW,
|
||||
"completion_metadata": None,
|
||||
}
|
||||
assert output_dict == expected_output
|
||||
|
||||
|
||||
159
tests/test_integration_llm_features.py
Normal file
159
tests/test_integration_llm_features.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from unittest.mock import Mock, patch
|
||||
from crewai import Agent, Task, Crew, LLM
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.utilities.xml_parser import extract_xml_content
|
||||
|
||||
|
||||
class TestIntegrationLLMFeatures:
|
||||
"""Integration tests for LLM features with agents and tasks."""
|
||||
|
||||
@patch('crewai.llm.litellm.completion')
|
||||
def test_agent_with_multiple_generations(self, mock_completion):
|
||||
"""Test agent execution with multiple generations."""
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [
|
||||
Mock(message=Mock(content="Generation 1")),
|
||||
Mock(message=Mock(content="Generation 2")),
|
||||
Mock(message=Mock(content="Generation 3")),
|
||||
]
|
||||
mock_response.usage = {"prompt_tokens": 20, "completion_tokens": 30}
|
||||
mock_response.model = "gpt-3.5-turbo"
|
||||
mock_response.created = 1234567890
|
||||
mock_response.id = "test-id"
|
||||
mock_response.object = "chat.completion"
|
||||
mock_response.system_fingerprint = "test-fingerprint"
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm = LLM(model="gpt-3.5-turbo", n=3, return_full_completion=True)
|
||||
agent = Agent(
|
||||
role="writer",
|
||||
goal="write content",
|
||||
backstory="You are a writer",
|
||||
llm=llm,
|
||||
return_completion_metadata=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Write a short story",
|
||||
agent=agent,
|
||||
expected_output="A short story",
|
||||
)
|
||||
|
||||
with patch.object(agent, 'agent_executor') as mock_executor:
|
||||
mock_executor.invoke.return_value = {"output": "Generation 1"}
|
||||
|
||||
result = agent.execute_task(task)
|
||||
assert result == "Generation 1"
|
||||
|
||||
@patch('crewai.llm.litellm.completion')
|
||||
def test_lite_agent_with_xml_extraction(self, mock_completion):
|
||||
"""Test LiteAgent with XML content extraction."""
|
||||
response_with_xml = """
|
||||
<thinking>
|
||||
I need to analyze this problem step by step.
|
||||
First, I'll consider the requirements.
|
||||
</thinking>
|
||||
|
||||
Based on my analysis, here's the solution: The answer is 42.
|
||||
"""
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock(message=Mock(content=response_with_xml))]
|
||||
mock_response.usage = {"prompt_tokens": 15, "completion_tokens": 25}
|
||||
mock_response.model = "gpt-3.5-turbo"
|
||||
mock_response.created = 1234567890
|
||||
mock_response.id = "test-id"
|
||||
mock_response.object = "chat.completion"
|
||||
mock_response.system_fingerprint = "test-fingerprint"
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
lite_agent = LiteAgent(
|
||||
role="analyst",
|
||||
goal="analyze problems",
|
||||
backstory="You are an analyst",
|
||||
llm=LLM(model="gpt-3.5-turbo", return_full_completion=True),
|
||||
)
|
||||
|
||||
with patch.object(lite_agent, '_invoke_loop') as mock_invoke:
|
||||
from crewai.agents.agent_action import AgentFinish
|
||||
mock_agent_finish = AgentFinish(output=response_with_xml, text=response_with_xml)
|
||||
mock_invoke.return_value = mock_agent_finish
|
||||
|
||||
result = lite_agent.kickoff("Analyze this problem")
|
||||
|
||||
thinking_content = extract_xml_content(result.raw, "thinking")
|
||||
assert thinking_content is not None
|
||||
assert "step by step" in thinking_content
|
||||
assert "requirements" in thinking_content
|
||||
|
||||
def test_xml_parser_with_complex_agent_output(self):
|
||||
"""Test XML parser with complex agent output containing multiple tags."""
|
||||
complex_output = """
|
||||
<thinking>
|
||||
This is a complex problem that requires careful analysis.
|
||||
I need to break it down into steps.
|
||||
</thinking>
|
||||
|
||||
<reasoning>
|
||||
Step 1: Understand the requirements
|
||||
Step 2: Analyze the constraints
|
||||
Step 3: Develop a solution
|
||||
</reasoning>
|
||||
|
||||
<conclusion>
|
||||
The best approach is to use a systematic methodology.
|
||||
</conclusion>
|
||||
|
||||
Final answer: Use the systematic approach outlined above.
|
||||
"""
|
||||
|
||||
thinking = extract_xml_content(complex_output, "thinking")
|
||||
reasoning = extract_xml_content(complex_output, "reasoning")
|
||||
conclusion = extract_xml_content(complex_output, "conclusion")
|
||||
|
||||
assert thinking is not None
|
||||
assert "complex problem" in thinking
|
||||
assert reasoning is not None
|
||||
assert "Step 1" in reasoning
|
||||
assert "Step 2" in reasoning
|
||||
assert "Step 3" in reasoning
|
||||
assert conclusion is not None
|
||||
assert "systematic methodology" in conclusion
|
||||
|
||||
@patch('crewai.llm.litellm.completion')
|
||||
def test_crew_with_llm_parameters(self, mock_completion):
|
||||
"""Test crew execution with LLM parameters."""
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock(message=Mock(content="Test response"))]
|
||||
mock_response.usage = {"prompt_tokens": 10, "completion_tokens": 5}
|
||||
mock_response.model = "gpt-3.5-turbo"
|
||||
mock_response.created = 1234567890
|
||||
mock_response.id = "test-id"
|
||||
mock_response.object = "chat.completion"
|
||||
mock_response.system_fingerprint = "test-fingerprint"
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
agent = Agent(
|
||||
role="analyst",
|
||||
goal="analyze data",
|
||||
backstory="You are an analyst",
|
||||
llm_n=2,
|
||||
llm_logprobs=5,
|
||||
return_completion_metadata=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Analyze the data",
|
||||
agent=agent,
|
||||
expected_output="Analysis results",
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
with patch.object(crew, '_run_sequential_process') as mock_run:
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
mock_output = CrewOutput(raw="Test response")
|
||||
mock_run.return_value = mock_output
|
||||
|
||||
result = crew.kickoff()
|
||||
assert result is not None
|
||||
226
tests/test_llm_generations_logprobs.py
Normal file
226
tests/test_llm_generations_logprobs.py
Normal file
@@ -0,0 +1,226 @@
|
||||
from unittest.mock import Mock, patch
|
||||
from crewai import Agent, LLM
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.lite_agent import LiteAgentOutput
|
||||
from crewai.utilities.xml_parser import (
|
||||
extract_xml_content,
|
||||
extract_all_xml_content,
|
||||
extract_multiple_xml_tags,
|
||||
extract_multiple_xml_tags_all,
|
||||
extract_xml_with_attributes,
|
||||
remove_xml_tags,
|
||||
strip_xml_tags_keep_content,
|
||||
)
|
||||
|
||||
|
||||
class TestLLMGenerationsLogprobs:
|
||||
"""Test suite for LLM generations and logprobs functionality."""
|
||||
|
||||
def test_llm_with_n_parameter(self):
|
||||
"""Test that LLM accepts n parameter for multiple generations."""
|
||||
llm = LLM(model="gpt-3.5-turbo", n=3)
|
||||
assert llm.n == 3
|
||||
|
||||
def test_llm_with_logprobs_parameter(self):
|
||||
"""Test that LLM accepts logprobs parameter."""
|
||||
llm = LLM(model="gpt-3.5-turbo", logprobs=5)
|
||||
assert llm.logprobs == 5
|
||||
|
||||
def test_llm_with_return_full_completion(self):
|
||||
"""Test that LLM accepts return_full_completion parameter."""
|
||||
llm = LLM(model="gpt-3.5-turbo", return_full_completion=True)
|
||||
assert llm.return_full_completion is True
|
||||
|
||||
def test_agent_with_llm_parameters(self):
|
||||
"""Test that Agent accepts LLM generation parameters."""
|
||||
agent = Agent(
|
||||
role="test",
|
||||
goal="test",
|
||||
backstory="test",
|
||||
llm_n=3,
|
||||
llm_logprobs=5,
|
||||
llm_top_logprobs=3,
|
||||
return_completion_metadata=True,
|
||||
)
|
||||
assert agent.llm_n == 3
|
||||
assert agent.llm_logprobs == 5
|
||||
assert agent.llm_top_logprobs == 3
|
||||
assert agent.return_completion_metadata is True
|
||||
|
||||
@patch('crewai.llm.litellm.completion')
|
||||
def test_llm_call_returns_full_completion(self, mock_completion):
|
||||
"""Test that LLM.call can return full completion object."""
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock()]
|
||||
mock_response.choices[0].message.content = "Test response"
|
||||
mock_response.usage = {"prompt_tokens": 10, "completion_tokens": 5}
|
||||
mock_response.model = "gpt-3.5-turbo"
|
||||
mock_response.created = 1234567890
|
||||
mock_response.id = "test-id"
|
||||
mock_response.object = "chat.completion"
|
||||
mock_response.system_fingerprint = "test-fingerprint"
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm = LLM(model="gpt-3.5-turbo", return_full_completion=True)
|
||||
result = llm.call("Test message")
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["content"] == "Test response"
|
||||
assert "choices" in result
|
||||
assert "usage" in result
|
||||
assert result["model"] == "gpt-3.5-turbo"
|
||||
|
||||
def test_task_output_completion_metadata(self):
|
||||
"""Test TaskOutput with completion metadata."""
|
||||
mock_choices = [
|
||||
Mock(message=Mock(content="Generation 1")),
|
||||
Mock(message=Mock(content="Generation 2")),
|
||||
]
|
||||
mock_usage = {"prompt_tokens": 10, "completion_tokens": 15}
|
||||
|
||||
completion_metadata = {
|
||||
"choices": mock_choices,
|
||||
"usage": mock_usage,
|
||||
"model": "gpt-3.5-turbo",
|
||||
}
|
||||
|
||||
task_output = TaskOutput(
|
||||
description="Test task",
|
||||
raw="Generation 1",
|
||||
agent="test-agent",
|
||||
completion_metadata=completion_metadata,
|
||||
)
|
||||
|
||||
generations = task_output.get_generations()
|
||||
assert generations == ["Generation 1", "Generation 2"]
|
||||
|
||||
usage = task_output.get_usage_metrics()
|
||||
assert usage == mock_usage
|
||||
|
||||
def test_lite_agent_output_completion_metadata(self):
|
||||
"""Test LiteAgentOutput with completion metadata."""
|
||||
mock_choices = [
|
||||
Mock(message=Mock(content="Generation 1")),
|
||||
Mock(message=Mock(content="Generation 2")),
|
||||
]
|
||||
mock_usage = {"prompt_tokens": 10, "completion_tokens": 15}
|
||||
|
||||
completion_metadata = {
|
||||
"choices": mock_choices,
|
||||
"usage": mock_usage,
|
||||
"model": "gpt-3.5-turbo",
|
||||
}
|
||||
|
||||
output = LiteAgentOutput(
|
||||
raw="Generation 1",
|
||||
agent_role="test-agent",
|
||||
completion_metadata=completion_metadata,
|
||||
)
|
||||
|
||||
generations = output.get_generations()
|
||||
assert generations == ["Generation 1", "Generation 2"]
|
||||
|
||||
usage = output.get_usage_metrics_from_completion()
|
||||
assert usage == mock_usage
|
||||
|
||||
|
||||
class TestXMLParser:
|
||||
"""Test suite for XML parsing functionality."""
|
||||
|
||||
def test_extract_xml_content_basic(self):
|
||||
"""Test basic XML content extraction."""
|
||||
text = "Some text <thinking>This is my thought</thinking> more text"
|
||||
result = extract_xml_content(text, "thinking")
|
||||
assert result == "This is my thought"
|
||||
|
||||
def test_extract_xml_content_not_found(self):
|
||||
"""Test XML content extraction when tag not found."""
|
||||
text = "Some text without the tag"
|
||||
result = extract_xml_content(text, "thinking")
|
||||
assert result is None
|
||||
|
||||
def test_extract_xml_content_multiline(self):
|
||||
"""Test XML content extraction with multiline content."""
|
||||
text = """Some text
|
||||
<thinking>
|
||||
This is a multiline
|
||||
thought process
|
||||
</thinking>
|
||||
more text"""
|
||||
result = extract_xml_content(text, "thinking")
|
||||
assert "multiline" in result
|
||||
assert "thought process" in result
|
||||
|
||||
def test_extract_all_xml_content(self):
|
||||
"""Test extracting all occurrences of XML content."""
|
||||
text = """
|
||||
<thinking>First thought</thinking>
|
||||
Some text
|
||||
<thinking>Second thought</thinking>
|
||||
"""
|
||||
result = extract_all_xml_content(text, "thinking")
|
||||
assert len(result) == 2
|
||||
assert result[0] == "First thought"
|
||||
assert result[1] == "Second thought"
|
||||
|
||||
def test_extract_multiple_xml_tags(self):
|
||||
"""Test extracting multiple different XML tags."""
|
||||
text = """
|
||||
<thinking>My thoughts</thinking>
|
||||
<reasoning>My reasoning</reasoning>
|
||||
<conclusion>My conclusion</conclusion>
|
||||
"""
|
||||
result = extract_multiple_xml_tags(text, ["thinking", "reasoning", "conclusion"])
|
||||
assert result["thinking"] == "My thoughts"
|
||||
assert result["reasoning"] == "My reasoning"
|
||||
assert result["conclusion"] == "My conclusion"
|
||||
|
||||
def test_extract_multiple_xml_tags_all(self):
|
||||
"""Test extracting all occurrences of multiple XML tags."""
|
||||
text = """
|
||||
<thinking>First thought</thinking>
|
||||
<reasoning>First reasoning</reasoning>
|
||||
<thinking>Second thought</thinking>
|
||||
"""
|
||||
result = extract_multiple_xml_tags_all(text, ["thinking", "reasoning"])
|
||||
assert len(result["thinking"]) == 2
|
||||
assert len(result["reasoning"]) == 1
|
||||
assert result["thinking"][0] == "First thought"
|
||||
assert result["thinking"][1] == "Second thought"
|
||||
|
||||
def test_extract_xml_with_attributes(self):
|
||||
"""Test extracting XML with attributes."""
|
||||
text = '<thinking type="deep" level="2">Complex thought</thinking>'
|
||||
result = extract_xml_with_attributes(text, "thinking")
|
||||
assert len(result) == 1
|
||||
assert result[0]["content"] == "Complex thought"
|
||||
assert result[0]["attributes"]["type"] == "deep"
|
||||
assert result[0]["attributes"]["level"] == "2"
|
||||
|
||||
def test_remove_xml_tags(self):
|
||||
"""Test removing XML tags and their content."""
|
||||
text = "Keep this <thinking>Remove this</thinking> and this"
|
||||
result = remove_xml_tags(text, ["thinking"])
|
||||
assert result == "Keep this and this"
|
||||
|
||||
def test_strip_xml_tags_keep_content(self):
|
||||
"""Test stripping XML tags but keeping content."""
|
||||
text = "Keep this <thinking>Keep this too</thinking> and this"
|
||||
result = strip_xml_tags_keep_content(text, ["thinking"])
|
||||
assert result == "Keep this Keep this too and this"
|
||||
|
||||
def test_nested_xml_tags(self):
|
||||
"""Test handling of nested XML tags."""
|
||||
text = "<outer>Before <inner>nested content</inner> after</outer>"
|
||||
result = extract_xml_content(text, "outer")
|
||||
assert "Before" in result
|
||||
assert "nested content" in result
|
||||
assert "after" in result
|
||||
|
||||
def test_xml_with_special_characters(self):
|
||||
"""Test XML parsing with special characters."""
|
||||
text = "<thinking>Content with & < > \" ' characters</thinking>"
|
||||
result = extract_xml_content(text, "thinking")
|
||||
assert "&" in result
|
||||
assert "<" in result
|
||||
assert ">" in result
|
||||
161
tests/test_xml_parser_examples.py
Normal file
161
tests/test_xml_parser_examples.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from crewai.utilities.xml_parser import (
|
||||
extract_xml_content,
|
||||
extract_all_xml_content,
|
||||
extract_multiple_xml_tags,
|
||||
remove_xml_tags,
|
||||
strip_xml_tags_keep_content,
|
||||
)
|
||||
|
||||
|
||||
class TestXMLParserExamples:
|
||||
"""Test XML parser with realistic agent output examples."""
|
||||
|
||||
def test_agent_thinking_extraction(self):
|
||||
"""Test extracting thinking content from agent output."""
|
||||
agent_output = """
|
||||
I need to solve this problem step by step.
|
||||
|
||||
<thinking>
|
||||
Let me break this down:
|
||||
1. First, I need to understand the requirements
|
||||
2. Then, I'll analyze the constraints
|
||||
3. Finally, I'll propose a solution
|
||||
|
||||
The key insight is that we need to balance efficiency with accuracy.
|
||||
</thinking>
|
||||
|
||||
Based on my analysis, here's my recommendation: Use approach A.
|
||||
"""
|
||||
|
||||
thinking = extract_xml_content(agent_output, "thinking")
|
||||
assert thinking is not None
|
||||
assert "break this down" in thinking
|
||||
assert "requirements" in thinking
|
||||
assert "constraints" in thinking
|
||||
assert "efficiency with accuracy" in thinking
|
||||
|
||||
def test_multiple_reasoning_tags(self):
|
||||
"""Test extracting multiple reasoning sections."""
|
||||
agent_output = """
|
||||
<reasoning>
|
||||
Initial analysis shows three possible approaches.
|
||||
</reasoning>
|
||||
|
||||
Let me explore each option:
|
||||
|
||||
<reasoning>
|
||||
Option A: Fast but less accurate
|
||||
Option B: Slow but very accurate
|
||||
Option C: Balanced approach
|
||||
</reasoning>
|
||||
|
||||
My final recommendation is Option C.
|
||||
"""
|
||||
|
||||
reasoning_sections = extract_all_xml_content(agent_output, "reasoning")
|
||||
assert len(reasoning_sections) == 2
|
||||
assert "three possible approaches" in reasoning_sections[0]
|
||||
assert "Option A" in reasoning_sections[1]
|
||||
assert "Option B" in reasoning_sections[1]
|
||||
assert "Option C" in reasoning_sections[1]
|
||||
|
||||
def test_complex_agent_workflow(self):
|
||||
"""Test complex agent output with multiple tag types."""
|
||||
complex_output = """
|
||||
<thinking>
|
||||
This is a complex problem requiring systematic analysis.
|
||||
I need to consider multiple factors.
|
||||
</thinking>
|
||||
|
||||
<analysis>
|
||||
Factor 1: Performance requirements
|
||||
Factor 2: Cost constraints
|
||||
Factor 3: Time limitations
|
||||
</analysis>
|
||||
|
||||
<reasoning>
|
||||
Given the analysis above, I believe we should prioritize performance
|
||||
while keeping costs reasonable. Time is less critical in this case.
|
||||
</reasoning>
|
||||
|
||||
<conclusion>
|
||||
Recommend Solution X with performance optimizations.
|
||||
</conclusion>
|
||||
|
||||
Final answer: Implement Solution X with the following optimizations...
|
||||
"""
|
||||
|
||||
extracted = extract_multiple_xml_tags(
|
||||
complex_output,
|
||||
["thinking", "analysis", "reasoning", "conclusion"]
|
||||
)
|
||||
|
||||
assert extracted["thinking"] is not None
|
||||
assert "systematic analysis" in extracted["thinking"]
|
||||
|
||||
assert extracted["analysis"] is not None
|
||||
assert "Factor 1" in extracted["analysis"]
|
||||
assert "Factor 2" in extracted["analysis"]
|
||||
assert "Factor 3" in extracted["analysis"]
|
||||
|
||||
assert extracted["reasoning"] is not None
|
||||
assert "prioritize performance" in extracted["reasoning"]
|
||||
|
||||
assert extracted["conclusion"] is not None
|
||||
assert "Solution X" in extracted["conclusion"]
|
||||
|
||||
def test_clean_output_for_user(self):
|
||||
"""Test cleaning agent output for user presentation."""
|
||||
raw_output = """
|
||||
<thinking>
|
||||
Internal reasoning that user shouldn't see.
|
||||
This contains implementation details.
|
||||
</thinking>
|
||||
|
||||
<debug>
|
||||
Debug information: variable X = 42
|
||||
</debug>
|
||||
|
||||
Here's the answer to your question: The solution is to use method Y.
|
||||
|
||||
<internal_notes>
|
||||
Remember to update the documentation later.
|
||||
</internal_notes>
|
||||
|
||||
This approach will give you the best results.
|
||||
"""
|
||||
|
||||
clean_output = remove_xml_tags(
|
||||
raw_output,
|
||||
["thinking", "debug", "internal_notes"]
|
||||
)
|
||||
|
||||
assert "Internal reasoning" not in clean_output
|
||||
assert "Debug information" not in clean_output
|
||||
assert "update the documentation" not in clean_output
|
||||
assert "Here's the answer" in clean_output
|
||||
assert "method Y" in clean_output
|
||||
assert "best results" in clean_output
|
||||
|
||||
def test_preserve_structured_content(self):
|
||||
"""Test preserving structured content while removing tags."""
|
||||
structured_output = """
|
||||
<steps>
|
||||
1. Initialize the system
|
||||
2. Load the configuration
|
||||
3. Process the data
|
||||
4. Generate the report
|
||||
</steps>
|
||||
|
||||
Follow these steps to complete the task.
|
||||
"""
|
||||
|
||||
clean_output = strip_xml_tags_keep_content(structured_output, ["steps"])
|
||||
|
||||
assert "<steps>" not in clean_output
|
||||
assert "</steps>" not in clean_output
|
||||
assert "1. Initialize" in clean_output
|
||||
assert "2. Load" in clean_output
|
||||
assert "3. Process" in clean_output
|
||||
assert "4. Generate" in clean_output
|
||||
assert "Follow these steps" in clean_output
|
||||
Reference in New Issue
Block a user