mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-26 08:38:15 +00:00
Compare commits
5 Commits
devin/1768
...
devin/1750
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76fea442d4 | ||
|
|
b1ae914cd3 | ||
|
|
14629bb87a | ||
|
|
a1ebdb125b | ||
|
|
39ea952acd |
166
examples/llm_generations_example.py
Normal file
166
examples/llm_generations_example.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
Example demonstrating the new LLM generations and logprobs functionality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from crewai import Agent, Task, LLM
|
||||||
|
from crewai.utilities.xml_parser import extract_xml_content
|
||||||
|
|
||||||
|
|
||||||
|
def example_multiple_generations():
|
||||||
|
"""Example of using multiple generations with an agent."""
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
n=3, # Request 3 generations
|
||||||
|
temperature=0.8, # Higher temperature for variety
|
||||||
|
return_full_completion=True
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Creative Writer",
|
||||||
|
goal="Write engaging content",
|
||||||
|
backstory="You are a creative writer who generates multiple ideas",
|
||||||
|
llm=llm,
|
||||||
|
return_completion_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Write a short story opening about a mysterious door",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="A compelling story opening"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
|
||||||
|
print("Primary result:")
|
||||||
|
print(result)
|
||||||
|
print("\n" + "="*50 + "\n")
|
||||||
|
|
||||||
|
if hasattr(task, 'output') and task.output.completion_metadata:
|
||||||
|
generations = task.output.get_generations()
|
||||||
|
if generations:
|
||||||
|
print(f"Generated {len(generations)} alternatives:")
|
||||||
|
for i, generation in enumerate(generations, 1):
|
||||||
|
print(f"\nGeneration {i}:")
|
||||||
|
print(generation)
|
||||||
|
print("-" * 30)
|
||||||
|
|
||||||
|
|
||||||
|
def example_xml_extraction():
|
||||||
|
"""Example of extracting structured content from agent output."""
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Problem Solver",
|
||||||
|
goal="Solve problems systematically",
|
||||||
|
backstory="You think step by step and show your reasoning",
|
||||||
|
llm=LLM(model="gpt-3.5-turbo")
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="""
|
||||||
|
Solve this problem: How can we reduce energy consumption in an office building?
|
||||||
|
|
||||||
|
Please structure your response with:
|
||||||
|
- <thinking> tags for your internal reasoning
|
||||||
|
- <analysis> tags for your analysis of the problem
|
||||||
|
- <solution> tags for your proposed solution
|
||||||
|
""",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="A structured solution with reasoning"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
|
||||||
|
print("Full agent output:")
|
||||||
|
print(result)
|
||||||
|
print("\n" + "="*50 + "\n")
|
||||||
|
|
||||||
|
thinking = extract_xml_content(result, "thinking")
|
||||||
|
analysis = extract_xml_content(result, "analysis")
|
||||||
|
solution = extract_xml_content(result, "solution")
|
||||||
|
|
||||||
|
if thinking:
|
||||||
|
print("Agent's thinking process:")
|
||||||
|
print(thinking)
|
||||||
|
print("\n" + "-"*30 + "\n")
|
||||||
|
|
||||||
|
if analysis:
|
||||||
|
print("Problem analysis:")
|
||||||
|
print(analysis)
|
||||||
|
print("\n" + "-"*30 + "\n")
|
||||||
|
|
||||||
|
if solution:
|
||||||
|
print("Proposed solution:")
|
||||||
|
print(solution)
|
||||||
|
|
||||||
|
|
||||||
|
def example_logprobs_analysis():
|
||||||
|
"""Example of accessing log probabilities for analysis."""
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
logprobs=5, # Request top 5 log probabilities
|
||||||
|
top_logprobs=3, # Show top 3 alternatives
|
||||||
|
return_full_completion=True
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Decision Analyst",
|
||||||
|
goal="Make confident decisions",
|
||||||
|
backstory="You analyze confidence levels in your responses",
|
||||||
|
llm=llm,
|
||||||
|
return_completion_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Should we invest in renewable energy? Give a yes/no answer with confidence.",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="A clear yes/no decision"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
|
||||||
|
print("Decision:")
|
||||||
|
print(result)
|
||||||
|
print("\n" + "="*50 + "\n")
|
||||||
|
|
||||||
|
if hasattr(task, 'output') and task.output.completion_metadata:
|
||||||
|
logprobs = task.output.get_logprobs()
|
||||||
|
usage = task.output.get_usage_metrics()
|
||||||
|
|
||||||
|
if logprobs:
|
||||||
|
print("Confidence analysis (log probabilities):")
|
||||||
|
print(f"Available logprobs data: {len(logprobs)} choices")
|
||||||
|
|
||||||
|
if usage:
|
||||||
|
print("\nToken usage:")
|
||||||
|
print(f"Prompt tokens: {usage.get('prompt_tokens', 'N/A')}")
|
||||||
|
print(f"Completion tokens: {usage.get('completion_tokens', 'N/A')}")
|
||||||
|
print(f"Total tokens: {usage.get('total_tokens', 'N/A')}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("=== CrewAI LLM Generations and XML Extraction Examples ===\n")
|
||||||
|
|
||||||
|
print("1. Multiple Generations Example:")
|
||||||
|
print("-" * 40)
|
||||||
|
try:
|
||||||
|
example_multiple_generations()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Example requires actual LLM API access: {e}")
|
||||||
|
|
||||||
|
print("\n\n2. XML Content Extraction Example:")
|
||||||
|
print("-" * 40)
|
||||||
|
try:
|
||||||
|
example_xml_extraction()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Example requires actual LLM API access: {e}")
|
||||||
|
|
||||||
|
print("\n\n3. Log Probabilities Analysis Example:")
|
||||||
|
print("-" * 40)
|
||||||
|
try:
|
||||||
|
example_logprobs_analysis()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Example requires actual LLM API access: {e}")
|
||||||
|
|
||||||
|
print("\n=== Examples completed ===")
|
||||||
@@ -88,6 +88,22 @@ class Agent(BaseAgent):
|
|||||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will run the agent.", default=None
|
||||||
)
|
)
|
||||||
|
llm_n: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Number of generations to request from the LLM.",
|
||||||
|
)
|
||||||
|
llm_logprobs: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Number of log probabilities to return from the LLM.",
|
||||||
|
)
|
||||||
|
llm_top_logprobs: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Number of top log probabilities to return from the LLM.",
|
||||||
|
)
|
||||||
|
return_completion_metadata: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether to return full completion metadata including generations and logprobs.",
|
||||||
|
)
|
||||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will run the agent.", default=None
|
||||||
)
|
)
|
||||||
@@ -179,6 +195,15 @@ class Agent(BaseAgent):
|
|||||||
):
|
):
|
||||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||||
|
|
||||||
|
if hasattr(self.llm, 'n') and self.llm_n is not None:
|
||||||
|
self.llm.n = self.llm_n
|
||||||
|
if hasattr(self.llm, 'logprobs') and self.llm_logprobs is not None:
|
||||||
|
self.llm.logprobs = self.llm_logprobs
|
||||||
|
if hasattr(self.llm, 'top_logprobs') and self.llm_top_logprobs is not None:
|
||||||
|
self.llm.top_logprobs = self.llm_top_logprobs
|
||||||
|
if hasattr(self.llm, 'return_full_completion'):
|
||||||
|
self.llm.return_full_completion = self.return_completion_metadata
|
||||||
|
|
||||||
if not self.agent_executor:
|
if not self.agent_executor:
|
||||||
self._setup_agent_executor()
|
self._setup_agent_executor()
|
||||||
|
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ def handle_user_input(
|
|||||||
available_functions=available_functions,
|
available_functions=available_functions,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages.append({"role": "assistant", "content": final_response})
|
messages.append({"role": "assistant", "content": str(final_response)})
|
||||||
click.secho(f"\nAssistant: {final_response}\n", fg="green")
|
click.secho(f"\nAssistant: {final_response}\n", fg="green")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -92,6 +92,9 @@ class LiteAgentOutput(BaseModel):
|
|||||||
usage_metrics: Optional[Dict[str, Any]] = Field(
|
usage_metrics: Optional[Dict[str, Any]] = Field(
|
||||||
description="Token usage metrics for this execution", default=None
|
description="Token usage metrics for this execution", default=None
|
||||||
)
|
)
|
||||||
|
completion_metadata: Optional[Dict[str, Any]] = Field(
|
||||||
|
description="Full completion metadata including generations and logprobs", default=None
|
||||||
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""Convert pydantic_output to a dictionary."""
|
"""Convert pydantic_output to a dictionary."""
|
||||||
@@ -99,6 +102,40 @@ class LiteAgentOutput(BaseModel):
|
|||||||
return self.pydantic.model_dump()
|
return self.pydantic.model_dump()
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def get_generations(self) -> Optional[List[str]]:
|
||||||
|
"""Get all generations from completion metadata."""
|
||||||
|
if not self.completion_metadata or "choices" not in self.completion_metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
generations = []
|
||||||
|
for choice in self.completion_metadata["choices"]:
|
||||||
|
if hasattr(choice, "message") and hasattr(choice.message, "content"):
|
||||||
|
generations.append(choice.message.content or "")
|
||||||
|
elif isinstance(choice, dict) and "message" in choice:
|
||||||
|
generations.append(choice["message"].get("content", ""))
|
||||||
|
|
||||||
|
return generations if generations else None
|
||||||
|
|
||||||
|
def get_logprobs(self) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
"""Get log probabilities from completion metadata."""
|
||||||
|
if not self.completion_metadata or "choices" not in self.completion_metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logprobs_list = []
|
||||||
|
for choice in self.completion_metadata["choices"]:
|
||||||
|
if hasattr(choice, "logprobs") and choice.logprobs:
|
||||||
|
logprobs_list.append(choice.logprobs)
|
||||||
|
elif isinstance(choice, dict) and "logprobs" in choice:
|
||||||
|
logprobs_list.append(choice["logprobs"])
|
||||||
|
|
||||||
|
return logprobs_list if logprobs_list else None
|
||||||
|
|
||||||
|
def get_usage_metrics_from_completion(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get token usage metrics from completion metadata."""
|
||||||
|
if not self.completion_metadata:
|
||||||
|
return None
|
||||||
|
return self.completion_metadata.get("usage")
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""String representation of the output."""
|
"""String representation of the output."""
|
||||||
if self.pydantic:
|
if self.pydantic:
|
||||||
|
|||||||
@@ -311,6 +311,7 @@ class LLM(BaseLLM):
|
|||||||
callbacks: List[Any] = [],
|
callbacks: List[Any] = [],
|
||||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
return_full_completion: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -337,6 +338,7 @@ class LLM(BaseLLM):
|
|||||||
self.additional_params = kwargs
|
self.additional_params = kwargs
|
||||||
self.is_anthropic = self._is_anthropic_model(model)
|
self.is_anthropic = self._is_anthropic_model(model)
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
self.return_full_completion = return_full_completion
|
||||||
|
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
@@ -419,16 +421,18 @@ class LLM(BaseLLM):
|
|||||||
params: Dict[str, Any],
|
params: Dict[str, Any],
|
||||||
callbacks: Optional[List[Any]] = None,
|
callbacks: Optional[List[Any]] = None,
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
) -> str:
|
return_full_completion: bool = False,
|
||||||
|
) -> Union[str, Dict[str, Any]]:
|
||||||
"""Handle a streaming response from the LLM.
|
"""Handle a streaming response from the LLM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: Parameters for the completion call
|
params: Parameters for the completion call
|
||||||
callbacks: Optional list of callback functions
|
callbacks: Optional list of callback functions
|
||||||
available_functions: Dict of available functions
|
available_functions: Dict of available functions
|
||||||
|
return_full_completion: Whether to return full completion object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The complete response text
|
Union[str, Dict[str, Any]]: The complete response text or full completion object
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If no content is received from the streaming response
|
Exception: If no content is received from the streaming response
|
||||||
@@ -626,11 +630,46 @@ class LLM(BaseLLM):
|
|||||||
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
|
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
|
||||||
# Emit completion event and return response
|
# Emit completion event and return response
|
||||||
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
||||||
|
|
||||||
|
if return_full_completion:
|
||||||
|
accumulated_choices = []
|
||||||
|
if last_chunk and hasattr(last_chunk, "choices"):
|
||||||
|
for choice in last_chunk.choices:
|
||||||
|
accumulated_choices.append({
|
||||||
|
"message": {"content": full_response},
|
||||||
|
"finish_reason": getattr(choice, "finish_reason", None),
|
||||||
|
"index": getattr(choice, "index", 0),
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
accumulated_choices = [{"message": {"content": full_response}, "finish_reason": "stop", "index": 0}]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": full_response,
|
||||||
|
"choices": accumulated_choices,
|
||||||
|
"usage": usage_info,
|
||||||
|
"model": params.get("model"),
|
||||||
|
"created": getattr(last_chunk, "created", None) if last_chunk else None,
|
||||||
|
"id": getattr(last_chunk, "id", None) if last_chunk else None,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": getattr(last_chunk, "system_fingerprint", None) if last_chunk else None,
|
||||||
|
}
|
||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
# --- 9) Handle tool calls if present
|
# --- 9) Handle tool calls if present
|
||||||
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
||||||
if tool_result is not None:
|
if tool_result is not None:
|
||||||
|
if return_full_completion:
|
||||||
|
return {
|
||||||
|
"content": tool_result,
|
||||||
|
"choices": [{"message": {"content": tool_result}}],
|
||||||
|
"usage": usage_info,
|
||||||
|
"model": params.get("model"),
|
||||||
|
"created": getattr(last_chunk, "created", None) if last_chunk else None,
|
||||||
|
"id": getattr(last_chunk, "id", None) if last_chunk else None,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": getattr(last_chunk, "system_fingerprint", None) if last_chunk else None,
|
||||||
|
}
|
||||||
return tool_result
|
return tool_result
|
||||||
|
|
||||||
# --- 10) Log token usage if available in streaming mode
|
# --- 10) Log token usage if available in streaming mode
|
||||||
@@ -638,6 +677,30 @@ class LLM(BaseLLM):
|
|||||||
|
|
||||||
# --- 11) Emit completion event and return response
|
# --- 11) Emit completion event and return response
|
||||||
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
|
||||||
|
|
||||||
|
if return_full_completion:
|
||||||
|
accumulated_choices = []
|
||||||
|
if last_chunk and hasattr(last_chunk, "choices"):
|
||||||
|
for choice in last_chunk.choices:
|
||||||
|
accumulated_choices.append({
|
||||||
|
"message": {"content": full_response},
|
||||||
|
"finish_reason": getattr(choice, "finish_reason", None),
|
||||||
|
"index": getattr(choice, "index", 0),
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
accumulated_choices = [{"message": {"content": full_response}, "finish_reason": "stop", "index": 0}]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"content": full_response,
|
||||||
|
"choices": accumulated_choices,
|
||||||
|
"usage": usage_info,
|
||||||
|
"model": params.get("model"),
|
||||||
|
"created": getattr(last_chunk, "created", None) if last_chunk else None,
|
||||||
|
"id": getattr(last_chunk, "id", None) if last_chunk else None,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": getattr(last_chunk, "system_fingerprint", None) if last_chunk else None,
|
||||||
|
}
|
||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
except ContextWindowExceededError as e:
|
except ContextWindowExceededError as e:
|
||||||
@@ -748,16 +811,18 @@ class LLM(BaseLLM):
|
|||||||
params: Dict[str, Any],
|
params: Dict[str, Any],
|
||||||
callbacks: Optional[List[Any]] = None,
|
callbacks: Optional[List[Any]] = None,
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
) -> str:
|
return_full_completion: bool = False,
|
||||||
|
) -> Union[str, Dict[str, Any]]:
|
||||||
"""Handle a non-streaming response from the LLM.
|
"""Handle a non-streaming response from the LLM.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: Parameters for the completion call
|
params: Parameters for the completion call
|
||||||
callbacks: Optional list of callback functions
|
callbacks: Optional list of callback functions
|
||||||
available_functions: Dict of available functions
|
available_functions: Dict of available functions
|
||||||
|
return_full_completion: Whether to return full completion object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: The response text
|
Union[str, Dict[str, Any]]: The response text or full completion object
|
||||||
"""
|
"""
|
||||||
# --- 1) Make the completion call
|
# --- 1) Make the completion call
|
||||||
try:
|
try:
|
||||||
@@ -793,18 +858,51 @@ class LLM(BaseLLM):
|
|||||||
# --- 4) Check for tool calls
|
# --- 4) Check for tool calls
|
||||||
tool_calls = getattr(response_message, "tool_calls", [])
|
tool_calls = getattr(response_message, "tool_calls", [])
|
||||||
|
|
||||||
# --- 5) If no tool calls or no available functions, return the text response directly
|
# --- 5) If no tool calls or no available functions, return the response
|
||||||
if not tool_calls or not available_functions:
|
if not tool_calls or not available_functions:
|
||||||
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
|
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
|
||||||
|
if return_full_completion:
|
||||||
|
return {
|
||||||
|
"content": text_response,
|
||||||
|
"choices": response.choices,
|
||||||
|
"usage": getattr(response, "usage", None),
|
||||||
|
"model": response.model,
|
||||||
|
"created": getattr(response, "created", None),
|
||||||
|
"id": getattr(response, "id", None),
|
||||||
|
"object": getattr(response, "object", "chat.completion"),
|
||||||
|
"system_fingerprint": getattr(response, "system_fingerprint", None),
|
||||||
|
}
|
||||||
return text_response
|
return text_response
|
||||||
|
|
||||||
# --- 6) Handle tool calls if present
|
# --- 6) Handle tool calls if present
|
||||||
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
||||||
if tool_result is not None:
|
if tool_result is not None:
|
||||||
|
if return_full_completion:
|
||||||
|
return {
|
||||||
|
"content": tool_result,
|
||||||
|
"choices": response.choices,
|
||||||
|
"usage": getattr(response, "usage", None),
|
||||||
|
"model": response.model,
|
||||||
|
"created": getattr(response, "created", None),
|
||||||
|
"id": getattr(response, "id", None),
|
||||||
|
"object": getattr(response, "object", "chat.completion"),
|
||||||
|
"system_fingerprint": getattr(response, "system_fingerprint", None),
|
||||||
|
}
|
||||||
return tool_result
|
return tool_result
|
||||||
|
|
||||||
# --- 7) If tool call handling didn't return a result, emit completion event and return text response
|
# --- 7) If tool call handling didn't return a result, emit completion event and return response
|
||||||
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
|
self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
|
||||||
|
if return_full_completion:
|
||||||
|
return {
|
||||||
|
"content": text_response,
|
||||||
|
"choices": response.choices,
|
||||||
|
"usage": getattr(response, "usage", None),
|
||||||
|
"model": response.model,
|
||||||
|
"created": getattr(response, "created", None),
|
||||||
|
"id": getattr(response, "id", None),
|
||||||
|
"object": getattr(response, "object", "chat.completion"),
|
||||||
|
"system_fingerprint": getattr(response, "system_fingerprint", None),
|
||||||
|
}
|
||||||
return text_response
|
return text_response
|
||||||
|
|
||||||
def _handle_tool_call(
|
def _handle_tool_call(
|
||||||
@@ -889,7 +987,8 @@ class LLM(BaseLLM):
|
|||||||
tools: Optional[List[dict]] = None,
|
tools: Optional[List[dict]] = None,
|
||||||
callbacks: Optional[List[Any]] = None,
|
callbacks: Optional[List[Any]] = None,
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
) -> Union[str, Any]:
|
return_full_completion: Optional[bool] = None,
|
||||||
|
) -> Union[str, Dict[str, Any]]:
|
||||||
"""High-level LLM call method.
|
"""High-level LLM call method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -903,10 +1002,11 @@ class LLM(BaseLLM):
|
|||||||
during and after the LLM call.
|
during and after the LLM call.
|
||||||
available_functions: Optional dict mapping function names to callables
|
available_functions: Optional dict mapping function names to callables
|
||||||
that can be invoked by the LLM.
|
that can be invoked by the LLM.
|
||||||
|
return_full_completion: Optional override for returning full completion object
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Union[str, Any]: Either a text response from the LLM (str) or
|
Union[str, Dict[str, Any]]: Either a text response from the LLM (str) or
|
||||||
the result of a tool function call (Any).
|
the full completion object with generations and metadata.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If messages format is invalid
|
TypeError: If messages format is invalid
|
||||||
@@ -944,17 +1044,20 @@ class LLM(BaseLLM):
|
|||||||
self.set_callbacks(callbacks)
|
self.set_callbacks(callbacks)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# --- 6) Prepare parameters for the completion call
|
# --- 6) Determine if we should return full completion
|
||||||
|
should_return_full = return_full_completion if return_full_completion is not None else self.return_full_completion
|
||||||
|
|
||||||
|
# --- 7) Prepare parameters for the completion call
|
||||||
params = self._prepare_completion_params(messages, tools)
|
params = self._prepare_completion_params(messages, tools)
|
||||||
|
|
||||||
# --- 7) Make the completion call and handle response
|
# --- 8) Make the completion call and handle response
|
||||||
if self.stream:
|
if self.stream:
|
||||||
return self._handle_streaming_response(
|
return self._handle_streaming_response(
|
||||||
params, callbacks, available_functions
|
params, callbacks, available_functions, should_return_full
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._handle_non_streaming_response(
|
return self._handle_non_streaming_response(
|
||||||
params, callbacks, available_functions
|
params, callbacks, available_functions, should_return_full
|
||||||
)
|
)
|
||||||
|
|
||||||
except LLMContextLengthExceededException:
|
except LLMContextLengthExceededException:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
@@ -26,6 +26,9 @@ class TaskOutput(BaseModel):
|
|||||||
output_format: OutputFormat = Field(
|
output_format: OutputFormat = Field(
|
||||||
description="Output format of the task", default=OutputFormat.RAW
|
description="Output format of the task", default=OutputFormat.RAW
|
||||||
)
|
)
|
||||||
|
completion_metadata: Optional[Dict[str, Any]] = Field(
|
||||||
|
description="Full completion metadata including generations and logprobs", default=None
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_summary(self):
|
def set_summary(self):
|
||||||
@@ -56,6 +59,40 @@ class TaskOutput(BaseModel):
|
|||||||
output_dict.update(self.pydantic.model_dump())
|
output_dict.update(self.pydantic.model_dump())
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|
||||||
|
def get_generations(self) -> Optional[List[str]]:
|
||||||
|
"""Get all generations from completion metadata."""
|
||||||
|
if not self.completion_metadata or "choices" not in self.completion_metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
generations = []
|
||||||
|
for choice in self.completion_metadata["choices"]:
|
||||||
|
if hasattr(choice, "message") and hasattr(choice.message, "content"):
|
||||||
|
generations.append(choice.message.content or "")
|
||||||
|
elif isinstance(choice, dict) and "message" in choice:
|
||||||
|
generations.append(choice["message"].get("content", ""))
|
||||||
|
|
||||||
|
return generations if generations else None
|
||||||
|
|
||||||
|
def get_logprobs(self) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
"""Get log probabilities from completion metadata."""
|
||||||
|
if not self.completion_metadata or "choices" not in self.completion_metadata:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logprobs_list = []
|
||||||
|
for choice in self.completion_metadata["choices"]:
|
||||||
|
if hasattr(choice, "logprobs") and choice.logprobs:
|
||||||
|
logprobs_list.append(choice.logprobs)
|
||||||
|
elif isinstance(choice, dict) and "logprobs" in choice:
|
||||||
|
logprobs_list.append(choice["logprobs"])
|
||||||
|
|
||||||
|
return logprobs_list if logprobs_list else None
|
||||||
|
|
||||||
|
def get_usage_metrics(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get token usage metrics from completion metadata."""
|
||||||
|
if not self.completion_metadata:
|
||||||
|
return None
|
||||||
|
return self.completion_metadata.get("usage")
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
if self.pydantic:
|
if self.pydantic:
|
||||||
return str(self.pydantic)
|
return str(self.pydantic)
|
||||||
|
|||||||
@@ -109,7 +109,11 @@ def handle_max_iterations_exceeded(
|
|||||||
)
|
)
|
||||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||||
|
|
||||||
formatted_answer = format_answer(answer)
|
text_answer = answer
|
||||||
|
if isinstance(answer, dict) and "content" in answer:
|
||||||
|
text_answer = answer["content"]
|
||||||
|
|
||||||
|
formatted_answer = format_answer(str(text_answer))
|
||||||
# Return the formatted answer, regardless of its type
|
# Return the formatted answer, regardless of its type
|
||||||
return formatted_answer
|
return formatted_answer
|
||||||
|
|
||||||
@@ -145,42 +149,64 @@ def get_llm_response(
|
|||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
callbacks: List[Any],
|
callbacks: List[Any],
|
||||||
printer: Printer,
|
printer: Printer,
|
||||||
) -> str:
|
return_full_completion: bool = False,
|
||||||
|
) -> Union[str, Dict[str, Any]]:
|
||||||
"""Call the LLM and return the response, handling any invalid responses."""
|
"""Call the LLM and return the response, handling any invalid responses."""
|
||||||
try:
|
try:
|
||||||
answer = llm.call(
|
from crewai.llm import LLM
|
||||||
messages,
|
if isinstance(llm, LLM) and return_full_completion:
|
||||||
callbacks=callbacks,
|
answer = llm.call(
|
||||||
)
|
messages,
|
||||||
|
callbacks=callbacks,
|
||||||
|
return_full_completion=return_full_completion,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
answer = llm.call(
|
||||||
|
messages,
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
printer.print(
|
printer.print(
|
||||||
content=f"Error during LLM call: {e}",
|
content=f"Error during LLM call: {e}",
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
if not answer:
|
|
||||||
printer.print(
|
if return_full_completion:
|
||||||
content="Received None or empty response from LLM call.",
|
if not answer or (isinstance(answer, dict) and not answer.get("content")):
|
||||||
color="red",
|
printer.print(
|
||||||
)
|
content="Received None or empty response from LLM call.",
|
||||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
color="red",
|
||||||
|
)
|
||||||
return answer
|
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||||
|
return answer
|
||||||
|
else:
|
||||||
|
if not answer:
|
||||||
|
printer.print(
|
||||||
|
content="Received None or empty response from LLM call.",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||||
|
return answer
|
||||||
|
|
||||||
|
|
||||||
def process_llm_response(
|
def process_llm_response(
|
||||||
answer: str, use_stop_words: bool
|
answer: Union[str, Dict[str, Any]], use_stop_words: bool
|
||||||
) -> Union[AgentAction, AgentFinish]:
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
|
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
|
||||||
|
text_answer = answer
|
||||||
|
if isinstance(answer, dict):
|
||||||
|
text_answer = answer.get("content", "")
|
||||||
|
|
||||||
if not use_stop_words:
|
if not use_stop_words:
|
||||||
try:
|
try:
|
||||||
# Preliminary parsing to check for errors.
|
# Preliminary parsing to check for errors.
|
||||||
format_answer(answer)
|
format_answer(str(text_answer))
|
||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
|
if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
|
||||||
answer = answer.split("Observation:")[0].strip()
|
text_answer = str(text_answer).split("Observation:")[0].strip()
|
||||||
|
|
||||||
return format_answer(answer)
|
return format_answer(str(text_answer))
|
||||||
|
|
||||||
|
|
||||||
def handle_agent_action_core(
|
def handle_agent_action_core(
|
||||||
|
|||||||
@@ -260,14 +260,22 @@ class AgentReasoning:
|
|||||||
available_functions={"create_reasoning_plan": _create_reasoning_plan},
|
available_functions={"create_reasoning_plan": _create_reasoning_plan},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logger.debug(f"Function calling response: {response[:100]}...")
|
if isinstance(response, dict):
|
||||||
|
response_str = str(response)
|
||||||
try:
|
self.logger.debug(f"Function calling response: {response_str[:100]}...")
|
||||||
result = json.loads(response)
|
|
||||||
if "plan" in result and "ready" in result:
|
if "plan" in response and "ready" in response:
|
||||||
return result["plan"], result["ready"]
|
return response["plan"], response["ready"]
|
||||||
except (json.JSONDecodeError, KeyError):
|
else:
|
||||||
pass
|
response_str = str(response)
|
||||||
|
self.logger.debug(f"Function calling response: {response_str[:100]}...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = json.loads(response_str)
|
||||||
|
if "plan" in result and "ready" in result:
|
||||||
|
return result["plan"], result["ready"]
|
||||||
|
except (json.JSONDecodeError, KeyError):
|
||||||
|
pass
|
||||||
|
|
||||||
response_str = str(response)
|
response_str = str(response)
|
||||||
return response_str, "READY: I am ready to execute the task." in response_str
|
return response_str, "READY: I am ready to execute the task." in response_str
|
||||||
|
|||||||
133
src/crewai/utilities/xml_parser.py
Normal file
133
src/crewai/utilities/xml_parser.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import re
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
def extract_xml_content(text: str, tag: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Extract content from a specific XML tag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to search in
|
||||||
|
tag: The XML tag name to extract (without angle brackets)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The content inside the first occurrence of the tag, or None if not found
|
||||||
|
"""
|
||||||
|
pattern = rf'<{re.escape(tag)}>(.*?)</{re.escape(tag)}>'
|
||||||
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
|
return match.group(1).strip() if match else None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_all_xml_content(text: str, tag: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Extract content from all occurrences of a specific XML tag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to search in
|
||||||
|
tag: The XML tag name to extract (without angle brackets)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of content strings from all occurrences of the tag
|
||||||
|
"""
|
||||||
|
pattern = rf'<{re.escape(tag)}>(.*?)</{re.escape(tag)}>'
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
return [match.strip() for match in matches]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_multiple_xml_tags(text: str, tags: List[str]) -> Dict[str, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Extract content from multiple XML tags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to search in
|
||||||
|
tags: List of XML tag names to extract (without angle brackets)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping tag names to their content (or None if not found)
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
for tag in tags:
|
||||||
|
result[tag] = extract_xml_content(text, tag)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def extract_multiple_xml_tags_all(text: str, tags: List[str]) -> Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Extract content from all occurrences of multiple XML tags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to search in
|
||||||
|
tags: List of XML tag names to extract (without angle brackets)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping tag names to lists of their content
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
for tag in tags:
|
||||||
|
result[tag] = extract_all_xml_content(text, tag)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def extract_xml_with_attributes(text: str, tag: str) -> List[Dict[str, Union[str, Dict[str, str]]]]:
|
||||||
|
"""
|
||||||
|
Extract XML tags with their attributes and content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to search in
|
||||||
|
tag: The XML tag name to extract (without angle brackets)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries containing 'content' and 'attributes' for each occurrence
|
||||||
|
"""
|
||||||
|
pattern = rf'<{re.escape(tag)}([^>]*)>(.*?)</{re.escape(tag)}>'
|
||||||
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for attrs_str, content in matches:
|
||||||
|
attributes = {}
|
||||||
|
if attrs_str.strip():
|
||||||
|
attr_pattern = r'(\w+)=["\']([^"\']*)["\']'
|
||||||
|
attributes = dict(re.findall(attr_pattern, attrs_str))
|
||||||
|
|
||||||
|
result.append({
|
||||||
|
'content': content.strip(),
|
||||||
|
'attributes': attributes
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def remove_xml_tags(text: str, tags: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
Remove specific XML tags and their content from text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to process
|
||||||
|
tags: List of XML tag names to remove (without angle brackets)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text with the specified XML tags and their content removed
|
||||||
|
"""
|
||||||
|
result = text
|
||||||
|
for tag in tags:
|
||||||
|
pattern = rf'<{re.escape(tag)}[^>]*>.*?</{re.escape(tag)}>'
|
||||||
|
result = re.sub(pattern, '', result, flags=re.DOTALL)
|
||||||
|
return result.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def strip_xml_tags_keep_content(text: str, tags: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
Remove specific XML tags but keep their content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to process
|
||||||
|
tags: List of XML tag names to strip (without angle brackets)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text with the specified XML tags removed but content preserved
|
||||||
|
"""
|
||||||
|
result = text
|
||||||
|
for tag in tags:
|
||||||
|
pattern = rf'<{re.escape(tag)}[^>]*>(.*?)</{re.escape(tag)}>'
|
||||||
|
result = re.sub(pattern, r'\1', result, flags=re.DOTALL)
|
||||||
|
return result.strip()
|
||||||
@@ -162,6 +162,7 @@ def test_task_callback_returns_task_output():
|
|||||||
"name": None,
|
"name": None,
|
||||||
"expected_output": "Bullet point list of 5 interesting ideas.",
|
"expected_output": "Bullet point list of 5 interesting ideas.",
|
||||||
"output_format": OutputFormat.RAW,
|
"output_format": OutputFormat.RAW,
|
||||||
|
"completion_metadata": None,
|
||||||
}
|
}
|
||||||
assert output_dict == expected_output
|
assert output_dict == expected_output
|
||||||
|
|
||||||
|
|||||||
159
tests/test_integration_llm_features.py
Normal file
159
tests/test_integration_llm_features.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from crewai import Agent, Task, Crew, LLM
|
||||||
|
from crewai.lite_agent import LiteAgent
|
||||||
|
from crewai.utilities.xml_parser import extract_xml_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegrationLLMFeatures:
|
||||||
|
"""Integration tests for LLM features with agents and tasks."""
|
||||||
|
|
||||||
|
@patch('crewai.llm.litellm.completion')
|
||||||
|
def test_agent_with_multiple_generations(self, mock_completion):
|
||||||
|
"""Test agent execution with multiple generations."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [
|
||||||
|
Mock(message=Mock(content="Generation 1")),
|
||||||
|
Mock(message=Mock(content="Generation 2")),
|
||||||
|
Mock(message=Mock(content="Generation 3")),
|
||||||
|
]
|
||||||
|
mock_response.usage = {"prompt_tokens": 20, "completion_tokens": 30}
|
||||||
|
mock_response.model = "gpt-3.5-turbo"
|
||||||
|
mock_response.created = 1234567890
|
||||||
|
mock_response.id = "test-id"
|
||||||
|
mock_response.object = "chat.completion"
|
||||||
|
mock_response.system_fingerprint = "test-fingerprint"
|
||||||
|
mock_completion.return_value = mock_response
|
||||||
|
|
||||||
|
llm = LLM(model="gpt-3.5-turbo", n=3, return_full_completion=True)
|
||||||
|
agent = Agent(
|
||||||
|
role="writer",
|
||||||
|
goal="write content",
|
||||||
|
backstory="You are a writer",
|
||||||
|
llm=llm,
|
||||||
|
return_completion_metadata=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Write a short story",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="A short story",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(agent, 'agent_executor') as mock_executor:
|
||||||
|
mock_executor.invoke.return_value = {"output": "Generation 1"}
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
assert result == "Generation 1"
|
||||||
|
|
||||||
|
@patch('crewai.llm.litellm.completion')
|
||||||
|
def test_lite_agent_with_xml_extraction(self, mock_completion):
|
||||||
|
"""Test LiteAgent with XML content extraction."""
|
||||||
|
response_with_xml = """
|
||||||
|
<thinking>
|
||||||
|
I need to analyze this problem step by step.
|
||||||
|
First, I'll consider the requirements.
|
||||||
|
</thinking>
|
||||||
|
|
||||||
|
Based on my analysis, here's the solution: The answer is 42.
|
||||||
|
"""
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock(message=Mock(content=response_with_xml))]
|
||||||
|
mock_response.usage = {"prompt_tokens": 15, "completion_tokens": 25}
|
||||||
|
mock_response.model = "gpt-3.5-turbo"
|
||||||
|
mock_response.created = 1234567890
|
||||||
|
mock_response.id = "test-id"
|
||||||
|
mock_response.object = "chat.completion"
|
||||||
|
mock_response.system_fingerprint = "test-fingerprint"
|
||||||
|
mock_completion.return_value = mock_response
|
||||||
|
|
||||||
|
lite_agent = LiteAgent(
|
||||||
|
role="analyst",
|
||||||
|
goal="analyze problems",
|
||||||
|
backstory="You are an analyst",
|
||||||
|
llm=LLM(model="gpt-3.5-turbo", return_full_completion=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(lite_agent, '_invoke_loop') as mock_invoke:
|
||||||
|
from crewai.agents.agent_action import AgentFinish
|
||||||
|
mock_agent_finish = AgentFinish(output=response_with_xml, text=response_with_xml)
|
||||||
|
mock_invoke.return_value = mock_agent_finish
|
||||||
|
|
||||||
|
result = lite_agent.kickoff("Analyze this problem")
|
||||||
|
|
||||||
|
thinking_content = extract_xml_content(result.raw, "thinking")
|
||||||
|
assert thinking_content is not None
|
||||||
|
assert "step by step" in thinking_content
|
||||||
|
assert "requirements" in thinking_content
|
||||||
|
|
||||||
|
def test_xml_parser_with_complex_agent_output(self):
|
||||||
|
"""Test XML parser with complex agent output containing multiple tags."""
|
||||||
|
complex_output = """
|
||||||
|
<thinking>
|
||||||
|
This is a complex problem that requires careful analysis.
|
||||||
|
I need to break it down into steps.
|
||||||
|
</thinking>
|
||||||
|
|
||||||
|
<reasoning>
|
||||||
|
Step 1: Understand the requirements
|
||||||
|
Step 2: Analyze the constraints
|
||||||
|
Step 3: Develop a solution
|
||||||
|
</reasoning>
|
||||||
|
|
||||||
|
<conclusion>
|
||||||
|
The best approach is to use a systematic methodology.
|
||||||
|
</conclusion>
|
||||||
|
|
||||||
|
Final answer: Use the systematic approach outlined above.
|
||||||
|
"""
|
||||||
|
|
||||||
|
thinking = extract_xml_content(complex_output, "thinking")
|
||||||
|
reasoning = extract_xml_content(complex_output, "reasoning")
|
||||||
|
conclusion = extract_xml_content(complex_output, "conclusion")
|
||||||
|
|
||||||
|
assert thinking is not None
|
||||||
|
assert "complex problem" in thinking
|
||||||
|
assert reasoning is not None
|
||||||
|
assert "Step 1" in reasoning
|
||||||
|
assert "Step 2" in reasoning
|
||||||
|
assert "Step 3" in reasoning
|
||||||
|
assert conclusion is not None
|
||||||
|
assert "systematic methodology" in conclusion
|
||||||
|
|
||||||
|
@patch('crewai.llm.litellm.completion')
|
||||||
|
def test_crew_with_llm_parameters(self, mock_completion):
|
||||||
|
"""Test crew execution with LLM parameters."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock(message=Mock(content="Test response"))]
|
||||||
|
mock_response.usage = {"prompt_tokens": 10, "completion_tokens": 5}
|
||||||
|
mock_response.model = "gpt-3.5-turbo"
|
||||||
|
mock_response.created = 1234567890
|
||||||
|
mock_response.id = "test-id"
|
||||||
|
mock_response.object = "chat.completion"
|
||||||
|
mock_response.system_fingerprint = "test-fingerprint"
|
||||||
|
mock_completion.return_value = mock_response
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="analyst",
|
||||||
|
goal="analyze data",
|
||||||
|
backstory="You are an analyst",
|
||||||
|
llm_n=2,
|
||||||
|
llm_logprobs=5,
|
||||||
|
return_completion_metadata=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Analyze the data",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="Analysis results",
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = Crew(agents=[agent], tasks=[task])
|
||||||
|
|
||||||
|
with patch.object(crew, '_run_sequential_process') as mock_run:
|
||||||
|
from crewai.crews.crew_output import CrewOutput
|
||||||
|
mock_output = CrewOutput(raw="Test response")
|
||||||
|
mock_run.return_value = mock_output
|
||||||
|
|
||||||
|
result = crew.kickoff()
|
||||||
|
assert result is not None
|
||||||
226
tests/test_llm_generations_logprobs.py
Normal file
226
tests/test_llm_generations_logprobs.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from crewai import Agent, LLM
|
||||||
|
from crewai.tasks.task_output import TaskOutput
|
||||||
|
from crewai.lite_agent import LiteAgentOutput
|
||||||
|
from crewai.utilities.xml_parser import (
|
||||||
|
extract_xml_content,
|
||||||
|
extract_all_xml_content,
|
||||||
|
extract_multiple_xml_tags,
|
||||||
|
extract_multiple_xml_tags_all,
|
||||||
|
extract_xml_with_attributes,
|
||||||
|
remove_xml_tags,
|
||||||
|
strip_xml_tags_keep_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMGenerationsLogprobs:
|
||||||
|
"""Test suite for LLM generations and logprobs functionality."""
|
||||||
|
|
||||||
|
def test_llm_with_n_parameter(self):
|
||||||
|
"""Test that LLM accepts n parameter for multiple generations."""
|
||||||
|
llm = LLM(model="gpt-3.5-turbo", n=3)
|
||||||
|
assert llm.n == 3
|
||||||
|
|
||||||
|
def test_llm_with_logprobs_parameter(self):
|
||||||
|
"""Test that LLM accepts logprobs parameter."""
|
||||||
|
llm = LLM(model="gpt-3.5-turbo", logprobs=5)
|
||||||
|
assert llm.logprobs == 5
|
||||||
|
|
||||||
|
def test_llm_with_return_full_completion(self):
|
||||||
|
"""Test that LLM accepts return_full_completion parameter."""
|
||||||
|
llm = LLM(model="gpt-3.5-turbo", return_full_completion=True)
|
||||||
|
assert llm.return_full_completion is True
|
||||||
|
|
||||||
|
def test_agent_with_llm_parameters(self):
|
||||||
|
"""Test that Agent accepts LLM generation parameters."""
|
||||||
|
agent = Agent(
|
||||||
|
role="test",
|
||||||
|
goal="test",
|
||||||
|
backstory="test",
|
||||||
|
llm_n=3,
|
||||||
|
llm_logprobs=5,
|
||||||
|
llm_top_logprobs=3,
|
||||||
|
return_completion_metadata=True,
|
||||||
|
)
|
||||||
|
assert agent.llm_n == 3
|
||||||
|
assert agent.llm_logprobs == 5
|
||||||
|
assert agent.llm_top_logprobs == 3
|
||||||
|
assert agent.return_completion_metadata is True
|
||||||
|
|
||||||
|
@patch('crewai.llm.litellm.completion')
|
||||||
|
def test_llm_call_returns_full_completion(self, mock_completion):
|
||||||
|
"""Test that LLM.call can return full completion object."""
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.choices = [Mock()]
|
||||||
|
mock_response.choices[0].message.content = "Test response"
|
||||||
|
mock_response.usage = {"prompt_tokens": 10, "completion_tokens": 5}
|
||||||
|
mock_response.model = "gpt-3.5-turbo"
|
||||||
|
mock_response.created = 1234567890
|
||||||
|
mock_response.id = "test-id"
|
||||||
|
mock_response.object = "chat.completion"
|
||||||
|
mock_response.system_fingerprint = "test-fingerprint"
|
||||||
|
mock_completion.return_value = mock_response
|
||||||
|
|
||||||
|
llm = LLM(model="gpt-3.5-turbo", return_full_completion=True)
|
||||||
|
result = llm.call("Test message")
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["content"] == "Test response"
|
||||||
|
assert "choices" in result
|
||||||
|
assert "usage" in result
|
||||||
|
assert result["model"] == "gpt-3.5-turbo"
|
||||||
|
|
||||||
|
def test_task_output_completion_metadata(self):
|
||||||
|
"""Test TaskOutput with completion metadata."""
|
||||||
|
mock_choices = [
|
||||||
|
Mock(message=Mock(content="Generation 1")),
|
||||||
|
Mock(message=Mock(content="Generation 2")),
|
||||||
|
]
|
||||||
|
mock_usage = {"prompt_tokens": 10, "completion_tokens": 15}
|
||||||
|
|
||||||
|
completion_metadata = {
|
||||||
|
"choices": mock_choices,
|
||||||
|
"usage": mock_usage,
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
}
|
||||||
|
|
||||||
|
task_output = TaskOutput(
|
||||||
|
description="Test task",
|
||||||
|
raw="Generation 1",
|
||||||
|
agent="test-agent",
|
||||||
|
completion_metadata=completion_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
generations = task_output.get_generations()
|
||||||
|
assert generations == ["Generation 1", "Generation 2"]
|
||||||
|
|
||||||
|
usage = task_output.get_usage_metrics()
|
||||||
|
assert usage == mock_usage
|
||||||
|
|
||||||
|
def test_lite_agent_output_completion_metadata(self):
|
||||||
|
"""Test LiteAgentOutput with completion metadata."""
|
||||||
|
mock_choices = [
|
||||||
|
Mock(message=Mock(content="Generation 1")),
|
||||||
|
Mock(message=Mock(content="Generation 2")),
|
||||||
|
]
|
||||||
|
mock_usage = {"prompt_tokens": 10, "completion_tokens": 15}
|
||||||
|
|
||||||
|
completion_metadata = {
|
||||||
|
"choices": mock_choices,
|
||||||
|
"usage": mock_usage,
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
}
|
||||||
|
|
||||||
|
output = LiteAgentOutput(
|
||||||
|
raw="Generation 1",
|
||||||
|
agent_role="test-agent",
|
||||||
|
completion_metadata=completion_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
generations = output.get_generations()
|
||||||
|
assert generations == ["Generation 1", "Generation 2"]
|
||||||
|
|
||||||
|
usage = output.get_usage_metrics_from_completion()
|
||||||
|
assert usage == mock_usage
|
||||||
|
|
||||||
|
|
||||||
|
class TestXMLParser:
|
||||||
|
"""Test suite for XML parsing functionality."""
|
||||||
|
|
||||||
|
def test_extract_xml_content_basic(self):
|
||||||
|
"""Test basic XML content extraction."""
|
||||||
|
text = "Some text <thinking>This is my thought</thinking> more text"
|
||||||
|
result = extract_xml_content(text, "thinking")
|
||||||
|
assert result == "This is my thought"
|
||||||
|
|
||||||
|
def test_extract_xml_content_not_found(self):
|
||||||
|
"""Test XML content extraction when tag not found."""
|
||||||
|
text = "Some text without the tag"
|
||||||
|
result = extract_xml_content(text, "thinking")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_extract_xml_content_multiline(self):
|
||||||
|
"""Test XML content extraction with multiline content."""
|
||||||
|
text = """Some text
|
||||||
|
<thinking>
|
||||||
|
This is a multiline
|
||||||
|
thought process
|
||||||
|
</thinking>
|
||||||
|
more text"""
|
||||||
|
result = extract_xml_content(text, "thinking")
|
||||||
|
assert "multiline" in result
|
||||||
|
assert "thought process" in result
|
||||||
|
|
||||||
|
def test_extract_all_xml_content(self):
|
||||||
|
"""Test extracting all occurrences of XML content."""
|
||||||
|
text = """
|
||||||
|
<thinking>First thought</thinking>
|
||||||
|
Some text
|
||||||
|
<thinking>Second thought</thinking>
|
||||||
|
"""
|
||||||
|
result = extract_all_xml_content(text, "thinking")
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0] == "First thought"
|
||||||
|
assert result[1] == "Second thought"
|
||||||
|
|
||||||
|
def test_extract_multiple_xml_tags(self):
|
||||||
|
"""Test extracting multiple different XML tags."""
|
||||||
|
text = """
|
||||||
|
<thinking>My thoughts</thinking>
|
||||||
|
<reasoning>My reasoning</reasoning>
|
||||||
|
<conclusion>My conclusion</conclusion>
|
||||||
|
"""
|
||||||
|
result = extract_multiple_xml_tags(text, ["thinking", "reasoning", "conclusion"])
|
||||||
|
assert result["thinking"] == "My thoughts"
|
||||||
|
assert result["reasoning"] == "My reasoning"
|
||||||
|
assert result["conclusion"] == "My conclusion"
|
||||||
|
|
||||||
|
def test_extract_multiple_xml_tags_all(self):
|
||||||
|
"""Test extracting all occurrences of multiple XML tags."""
|
||||||
|
text = """
|
||||||
|
<thinking>First thought</thinking>
|
||||||
|
<reasoning>First reasoning</reasoning>
|
||||||
|
<thinking>Second thought</thinking>
|
||||||
|
"""
|
||||||
|
result = extract_multiple_xml_tags_all(text, ["thinking", "reasoning"])
|
||||||
|
assert len(result["thinking"]) == 2
|
||||||
|
assert len(result["reasoning"]) == 1
|
||||||
|
assert result["thinking"][0] == "First thought"
|
||||||
|
assert result["thinking"][1] == "Second thought"
|
||||||
|
|
||||||
|
def test_extract_xml_with_attributes(self):
|
||||||
|
"""Test extracting XML with attributes."""
|
||||||
|
text = '<thinking type="deep" level="2">Complex thought</thinking>'
|
||||||
|
result = extract_xml_with_attributes(text, "thinking")
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["content"] == "Complex thought"
|
||||||
|
assert result[0]["attributes"]["type"] == "deep"
|
||||||
|
assert result[0]["attributes"]["level"] == "2"
|
||||||
|
|
||||||
|
def test_remove_xml_tags(self):
|
||||||
|
"""Test removing XML tags and their content."""
|
||||||
|
text = "Keep this <thinking>Remove this</thinking> and this"
|
||||||
|
result = remove_xml_tags(text, ["thinking"])
|
||||||
|
assert result == "Keep this and this"
|
||||||
|
|
||||||
|
def test_strip_xml_tags_keep_content(self):
|
||||||
|
"""Test stripping XML tags but keeping content."""
|
||||||
|
text = "Keep this <thinking>Keep this too</thinking> and this"
|
||||||
|
result = strip_xml_tags_keep_content(text, ["thinking"])
|
||||||
|
assert result == "Keep this Keep this too and this"
|
||||||
|
|
||||||
|
def test_nested_xml_tags(self):
|
||||||
|
"""Test handling of nested XML tags."""
|
||||||
|
text = "<outer>Before <inner>nested content</inner> after</outer>"
|
||||||
|
result = extract_xml_content(text, "outer")
|
||||||
|
assert "Before" in result
|
||||||
|
assert "nested content" in result
|
||||||
|
assert "after" in result
|
||||||
|
|
||||||
|
def test_xml_with_special_characters(self):
|
||||||
|
"""Test XML parsing with special characters."""
|
||||||
|
text = "<thinking>Content with & < > \" ' characters</thinking>"
|
||||||
|
result = extract_xml_content(text, "thinking")
|
||||||
|
assert "&" in result
|
||||||
|
assert "<" in result
|
||||||
|
assert ">" in result
|
||||||
161
tests/test_xml_parser_examples.py
Normal file
161
tests/test_xml_parser_examples.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
from crewai.utilities.xml_parser import (
|
||||||
|
extract_xml_content,
|
||||||
|
extract_all_xml_content,
|
||||||
|
extract_multiple_xml_tags,
|
||||||
|
remove_xml_tags,
|
||||||
|
strip_xml_tags_keep_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestXMLParserExamples:
|
||||||
|
"""Test XML parser with realistic agent output examples."""
|
||||||
|
|
||||||
|
def test_agent_thinking_extraction(self):
|
||||||
|
"""Test extracting thinking content from agent output."""
|
||||||
|
agent_output = """
|
||||||
|
I need to solve this problem step by step.
|
||||||
|
|
||||||
|
<thinking>
|
||||||
|
Let me break this down:
|
||||||
|
1. First, I need to understand the requirements
|
||||||
|
2. Then, I'll analyze the constraints
|
||||||
|
3. Finally, I'll propose a solution
|
||||||
|
|
||||||
|
The key insight is that we need to balance efficiency with accuracy.
|
||||||
|
</thinking>
|
||||||
|
|
||||||
|
Based on my analysis, here's my recommendation: Use approach A.
|
||||||
|
"""
|
||||||
|
|
||||||
|
thinking = extract_xml_content(agent_output, "thinking")
|
||||||
|
assert thinking is not None
|
||||||
|
assert "break this down" in thinking
|
||||||
|
assert "requirements" in thinking
|
||||||
|
assert "constraints" in thinking
|
||||||
|
assert "efficiency with accuracy" in thinking
|
||||||
|
|
||||||
|
def test_multiple_reasoning_tags(self):
|
||||||
|
"""Test extracting multiple reasoning sections."""
|
||||||
|
agent_output = """
|
||||||
|
<reasoning>
|
||||||
|
Initial analysis shows three possible approaches.
|
||||||
|
</reasoning>
|
||||||
|
|
||||||
|
Let me explore each option:
|
||||||
|
|
||||||
|
<reasoning>
|
||||||
|
Option A: Fast but less accurate
|
||||||
|
Option B: Slow but very accurate
|
||||||
|
Option C: Balanced approach
|
||||||
|
</reasoning>
|
||||||
|
|
||||||
|
My final recommendation is Option C.
|
||||||
|
"""
|
||||||
|
|
||||||
|
reasoning_sections = extract_all_xml_content(agent_output, "reasoning")
|
||||||
|
assert len(reasoning_sections) == 2
|
||||||
|
assert "three possible approaches" in reasoning_sections[0]
|
||||||
|
assert "Option A" in reasoning_sections[1]
|
||||||
|
assert "Option B" in reasoning_sections[1]
|
||||||
|
assert "Option C" in reasoning_sections[1]
|
||||||
|
|
||||||
|
def test_complex_agent_workflow(self):
|
||||||
|
"""Test complex agent output with multiple tag types."""
|
||||||
|
complex_output = """
|
||||||
|
<thinking>
|
||||||
|
This is a complex problem requiring systematic analysis.
|
||||||
|
I need to consider multiple factors.
|
||||||
|
</thinking>
|
||||||
|
|
||||||
|
<analysis>
|
||||||
|
Factor 1: Performance requirements
|
||||||
|
Factor 2: Cost constraints
|
||||||
|
Factor 3: Time limitations
|
||||||
|
</analysis>
|
||||||
|
|
||||||
|
<reasoning>
|
||||||
|
Given the analysis above, I believe we should prioritize performance
|
||||||
|
while keeping costs reasonable. Time is less critical in this case.
|
||||||
|
</reasoning>
|
||||||
|
|
||||||
|
<conclusion>
|
||||||
|
Recommend Solution X with performance optimizations.
|
||||||
|
</conclusion>
|
||||||
|
|
||||||
|
Final answer: Implement Solution X with the following optimizations...
|
||||||
|
"""
|
||||||
|
|
||||||
|
extracted = extract_multiple_xml_tags(
|
||||||
|
complex_output,
|
||||||
|
["thinking", "analysis", "reasoning", "conclusion"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert extracted["thinking"] is not None
|
||||||
|
assert "systematic analysis" in extracted["thinking"]
|
||||||
|
|
||||||
|
assert extracted["analysis"] is not None
|
||||||
|
assert "Factor 1" in extracted["analysis"]
|
||||||
|
assert "Factor 2" in extracted["analysis"]
|
||||||
|
assert "Factor 3" in extracted["analysis"]
|
||||||
|
|
||||||
|
assert extracted["reasoning"] is not None
|
||||||
|
assert "prioritize performance" in extracted["reasoning"]
|
||||||
|
|
||||||
|
assert extracted["conclusion"] is not None
|
||||||
|
assert "Solution X" in extracted["conclusion"]
|
||||||
|
|
||||||
|
def test_clean_output_for_user(self):
|
||||||
|
"""Test cleaning agent output for user presentation."""
|
||||||
|
raw_output = """
|
||||||
|
<thinking>
|
||||||
|
Internal reasoning that user shouldn't see.
|
||||||
|
This contains implementation details.
|
||||||
|
</thinking>
|
||||||
|
|
||||||
|
<debug>
|
||||||
|
Debug information: variable X = 42
|
||||||
|
</debug>
|
||||||
|
|
||||||
|
Here's the answer to your question: The solution is to use method Y.
|
||||||
|
|
||||||
|
<internal_notes>
|
||||||
|
Remember to update the documentation later.
|
||||||
|
</internal_notes>
|
||||||
|
|
||||||
|
This approach will give you the best results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
clean_output = remove_xml_tags(
|
||||||
|
raw_output,
|
||||||
|
["thinking", "debug", "internal_notes"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Internal reasoning" not in clean_output
|
||||||
|
assert "Debug information" not in clean_output
|
||||||
|
assert "update the documentation" not in clean_output
|
||||||
|
assert "Here's the answer" in clean_output
|
||||||
|
assert "method Y" in clean_output
|
||||||
|
assert "best results" in clean_output
|
||||||
|
|
||||||
|
def test_preserve_structured_content(self):
|
||||||
|
"""Test preserving structured content while removing tags."""
|
||||||
|
structured_output = """
|
||||||
|
<steps>
|
||||||
|
1. Initialize the system
|
||||||
|
2. Load the configuration
|
||||||
|
3. Process the data
|
||||||
|
4. Generate the report
|
||||||
|
</steps>
|
||||||
|
|
||||||
|
Follow these steps to complete the task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
clean_output = strip_xml_tags_keep_content(structured_output, ["steps"])
|
||||||
|
|
||||||
|
assert "<steps>" not in clean_output
|
||||||
|
assert "</steps>" not in clean_output
|
||||||
|
assert "1. Initialize" in clean_output
|
||||||
|
assert "2. Load" in clean_output
|
||||||
|
assert "3. Process" in clean_output
|
||||||
|
assert "4. Generate" in clean_output
|
||||||
|
assert "Follow these steps" in clean_output
|
||||||
Reference in New Issue
Block a user