diff --git a/examples/llm_generations_example.py b/examples/llm_generations_example.py new file mode 100644 index 000000000..688f1536c --- /dev/null +++ b/examples/llm_generations_example.py @@ -0,0 +1,166 @@ +""" +Example demonstrating the new LLM generations and logprobs functionality. +""" + +from crewai import Agent, Task, Crew, LLM +from crewai.utilities.xml_parser import extract_xml_content, extract_multiple_xml_tags + + +def example_multiple_generations(): + """Example of using multiple generations with an agent.""" + + llm = LLM( + model="gpt-3.5-turbo", + n=3, # Request 3 generations + temperature=0.8, # Higher temperature for variety + return_full_completion=True + ) + + agent = Agent( + role="Creative Writer", + goal="Write engaging content", + backstory="You are a creative writer who generates multiple ideas", + llm=llm, + return_completion_metadata=True + ) + + task = Task( + description="Write a short story opening about a mysterious door", + agent=agent, + expected_output="A compelling story opening" + ) + + result = agent.execute_task(task) + + print("Primary result:") + print(result) + print("\n" + "="*50 + "\n") + + if hasattr(task, 'output') and task.output.completion_metadata: + generations = task.output.get_generations() + if generations: + print(f"Generated {len(generations)} alternatives:") + for i, generation in enumerate(generations, 1): + print(f"\nGeneration {i}:") + print(generation) + print("-" * 30) + + +def example_xml_extraction(): + """Example of extracting structured content from agent output.""" + + agent = Agent( + role="Problem Solver", + goal="Solve problems systematically", + backstory="You think step by step and show your reasoning", + llm=LLM(model="gpt-3.5-turbo") + ) + + task = Task( + description=""" + Solve this problem: How can we reduce energy consumption in an office building? + + Please structure your response with: + - tags for your internal reasoning + - tags for your analysis of the problem + - tags for your proposed solution + """, + agent=agent, + expected_output="A structured solution with reasoning" + ) + + result = agent.execute_task(task) + + print("Full agent output:") + print(result) + print("\n" + "="*50 + "\n") + + thinking = extract_xml_content(result, "thinking") + analysis = extract_xml_content(result, "analysis") + solution = extract_xml_content(result, "solution") + + if thinking: + print("Agent's thinking process:") + print(thinking) + print("\n" + "-"*30 + "\n") + + if analysis: + print("Problem analysis:") + print(analysis) + print("\n" + "-"*30 + "\n") + + if solution: + print("Proposed solution:") + print(solution) + + +def example_logprobs_analysis(): + """Example of accessing log probabilities for analysis.""" + + llm = LLM( + model="gpt-3.5-turbo", + logprobs=5, # Request top 5 log probabilities + top_logprobs=3, # Show top 3 alternatives + return_full_completion=True + ) + + agent = Agent( + role="Decision Analyst", + goal="Make confident decisions", + backstory="You analyze confidence levels in your responses", + llm=llm, + return_completion_metadata=True + ) + + task = Task( + description="Should we invest in renewable energy? Give a yes/no answer with confidence.", + agent=agent, + expected_output="A clear yes/no decision" + ) + + result = agent.execute_task(task) + + print("Decision:") + print(result) + print("\n" + "="*50 + "\n") + + if hasattr(task, 'output') and task.output.completion_metadata: + logprobs = task.output.get_logprobs() + usage = task.output.get_usage_metrics() + + if logprobs: + print("Confidence analysis (log probabilities):") + print(f"Available logprobs data: {len(logprobs)} choices") + + if usage: + print(f"\nToken usage:") + print(f"Prompt tokens: {usage.get('prompt_tokens', 'N/A')}") + print(f"Completion tokens: {usage.get('completion_tokens', 'N/A')}") + print(f"Total tokens: {usage.get('total_tokens', 'N/A')}") + + +if __name__ == "__main__": + print("=== CrewAI LLM Generations and XML Extraction Examples ===\n") + + print("1. Multiple Generations Example:") + print("-" * 40) + try: + example_multiple_generations() + except Exception as e: + print(f"Example requires actual LLM API access: {e}") + + print("\n\n2. XML Content Extraction Example:") + print("-" * 40) + try: + example_xml_extraction() + except Exception as e: + print(f"Example requires actual LLM API access: {e}") + + print("\n\n3. Log Probabilities Analysis Example:") + print("-" * 40) + try: + example_logprobs_analysis() + except Exception as e: + print(f"Example requires actual LLM API access: {e}") + + print("\n=== Examples completed ===") diff --git a/src/crewai/agent.py b/src/crewai/agent.py index c8e34b2e6..103914846 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -88,6 +88,22 @@ class Agent(BaseAgent): llm: Union[str, InstanceOf[BaseLLM], Any] = Field( description="Language model that will run the agent.", default=None ) + llm_n: Optional[int] = Field( + default=None, + description="Number of generations to request from the LLM.", + ) + llm_logprobs: Optional[int] = Field( + default=None, + description="Number of log probabilities to return from the LLM.", + ) + llm_top_logprobs: Optional[int] = Field( + default=None, + description="Number of top log probabilities to return from the LLM.", + ) + return_completion_metadata: bool = Field( + default=False, + description="Whether to return full completion metadata including generations and logprobs.", + ) function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( description="Language model that will run the agent.", default=None ) @@ -179,6 +195,15 @@ class Agent(BaseAgent): ): self.function_calling_llm = create_llm(self.function_calling_llm) + if hasattr(self.llm, 'n') and self.llm_n is not None: + self.llm.n = self.llm_n + if hasattr(self.llm, 'logprobs') and self.llm_logprobs is not None: + self.llm.logprobs = self.llm_logprobs + if hasattr(self.llm, 'top_logprobs') and self.llm_top_logprobs is not None: + self.llm.top_logprobs = self.llm_top_logprobs + if hasattr(self.llm, 'return_full_completion'): + self.llm.return_full_completion = self.return_completion_metadata + if not self.agent_executor: self._setup_agent_executor() diff --git a/src/crewai/lite_agent.py b/src/crewai/lite_agent.py index 8dfbfaff8..3305b3778 100644 --- a/src/crewai/lite_agent.py +++ b/src/crewai/lite_agent.py @@ -92,6 +92,9 @@ class LiteAgentOutput(BaseModel): usage_metrics: Optional[Dict[str, Any]] = Field( description="Token usage metrics for this execution", default=None ) + completion_metadata: Optional[Dict[str, Any]] = Field( + description="Full completion metadata including generations and logprobs", default=None + ) def to_dict(self) -> Dict[str, Any]: """Convert pydantic_output to a dictionary.""" @@ -99,6 +102,40 @@ class LiteAgentOutput(BaseModel): return self.pydantic.model_dump() return {} + def get_generations(self) -> Optional[List[str]]: + """Get all generations from completion metadata.""" + if not self.completion_metadata or "choices" not in self.completion_metadata: + return None + + generations = [] + for choice in self.completion_metadata["choices"]: + if hasattr(choice, "message") and hasattr(choice.message, "content"): + generations.append(choice.message.content or "") + elif isinstance(choice, dict) and "message" in choice: + generations.append(choice["message"].get("content", "")) + + return generations if generations else None + + def get_logprobs(self) -> Optional[List[Dict[str, Any]]]: + """Get log probabilities from completion metadata.""" + if not self.completion_metadata or "choices" not in self.completion_metadata: + return None + + logprobs_list = [] + for choice in self.completion_metadata["choices"]: + if hasattr(choice, "logprobs") and choice.logprobs: + logprobs_list.append(choice.logprobs) + elif isinstance(choice, dict) and "logprobs" in choice: + logprobs_list.append(choice["logprobs"]) + + return logprobs_list if logprobs_list else None + + def get_usage_metrics_from_completion(self) -> Optional[Dict[str, Any]]: + """Get token usage metrics from completion metadata.""" + if not self.completion_metadata: + return None + return self.completion_metadata.get("usage") + def __str__(self) -> str: """String representation of the output.""" if self.pydantic: diff --git a/src/crewai/llm.py b/src/crewai/llm.py index f30ed080f..d2ff4b1b3 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -311,6 +311,7 @@ class LLM(BaseLLM): callbacks: List[Any] = [], reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, stream: bool = False, + return_full_completion: bool = False, **kwargs, ): self.model = model @@ -337,6 +338,7 @@ class LLM(BaseLLM): self.additional_params = kwargs self.is_anthropic = self._is_anthropic_model(model) self.stream = stream + self.return_full_completion = return_full_completion litellm.drop_params = True @@ -419,16 +421,18 @@ class LLM(BaseLLM): params: Dict[str, Any], callbacks: Optional[List[Any]] = None, available_functions: Optional[Dict[str, Any]] = None, - ) -> str: + return_full_completion: bool = False, + ) -> Union[str, Dict[str, Any]]: """Handle a streaming response from the LLM. Args: params: Parameters for the completion call callbacks: Optional list of callback functions available_functions: Dict of available functions + return_full_completion: Whether to return full completion object Returns: - str: The complete response text + Union[str, Dict[str, Any]]: The complete response text or full completion object Raises: Exception: If no content is received from the streaming response @@ -626,11 +630,46 @@ class LLM(BaseLLM): self._handle_streaming_callbacks(callbacks, usage_info, last_chunk) # Emit completion event and return response self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL) + + if return_full_completion: + accumulated_choices = [] + if last_chunk and hasattr(last_chunk, "choices"): + for choice in last_chunk.choices: + accumulated_choices.append({ + "message": {"content": full_response}, + "finish_reason": getattr(choice, "finish_reason", None), + "index": getattr(choice, "index", 0), + }) + else: + accumulated_choices = [{"message": {"content": full_response}, "finish_reason": "stop", "index": 0}] + + return { + "content": full_response, + "choices": accumulated_choices, + "usage": usage_info, + "model": params.get("model"), + "created": getattr(last_chunk, "created", None) if last_chunk else None, + "id": getattr(last_chunk, "id", None) if last_chunk else None, + "object": "chat.completion", + "system_fingerprint": getattr(last_chunk, "system_fingerprint", None) if last_chunk else None, + } + return full_response # --- 9) Handle tool calls if present tool_result = self._handle_tool_call(tool_calls, available_functions) if tool_result is not None: + if return_full_completion: + return { + "content": tool_result, + "choices": [{"message": {"content": tool_result}}], + "usage": usage_info, + "model": params.get("model"), + "created": getattr(last_chunk, "created", None) if last_chunk else None, + "id": getattr(last_chunk, "id", None) if last_chunk else None, + "object": "chat.completion", + "system_fingerprint": getattr(last_chunk, "system_fingerprint", None) if last_chunk else None, + } return tool_result # --- 10) Log token usage if available in streaming mode @@ -638,6 +677,30 @@ class LLM(BaseLLM): # --- 11) Emit completion event and return response self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL) + + if return_full_completion: + accumulated_choices = [] + if last_chunk and hasattr(last_chunk, "choices"): + for choice in last_chunk.choices: + accumulated_choices.append({ + "message": {"content": full_response}, + "finish_reason": getattr(choice, "finish_reason", None), + "index": getattr(choice, "index", 0), + }) + else: + accumulated_choices = [{"message": {"content": full_response}, "finish_reason": "stop", "index": 0}] + + return { + "content": full_response, + "choices": accumulated_choices, + "usage": usage_info, + "model": params.get("model"), + "created": getattr(last_chunk, "created", None) if last_chunk else None, + "id": getattr(last_chunk, "id", None) if last_chunk else None, + "object": "chat.completion", + "system_fingerprint": getattr(last_chunk, "system_fingerprint", None) if last_chunk else None, + } + return full_response except ContextWindowExceededError as e: @@ -748,16 +811,18 @@ class LLM(BaseLLM): params: Dict[str, Any], callbacks: Optional[List[Any]] = None, available_functions: Optional[Dict[str, Any]] = None, - ) -> str: + return_full_completion: bool = False, + ) -> Union[str, Dict[str, Any]]: """Handle a non-streaming response from the LLM. Args: params: Parameters for the completion call callbacks: Optional list of callback functions available_functions: Dict of available functions + return_full_completion: Whether to return full completion object Returns: - str: The response text + Union[str, Dict[str, Any]]: The response text or full completion object """ # --- 1) Make the completion call try: @@ -793,18 +858,51 @@ class LLM(BaseLLM): # --- 4) Check for tool calls tool_calls = getattr(response_message, "tool_calls", []) - # --- 5) If no tool calls or no available functions, return the text response directly + # --- 5) If no tool calls or no available functions, return the response if not tool_calls or not available_functions: self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL) + if return_full_completion: + return { + "content": text_response, + "choices": response.choices, + "usage": getattr(response, "usage", None), + "model": response.model, + "created": getattr(response, "created", None), + "id": getattr(response, "id", None), + "object": getattr(response, "object", "chat.completion"), + "system_fingerprint": getattr(response, "system_fingerprint", None), + } return text_response # --- 6) Handle tool calls if present tool_result = self._handle_tool_call(tool_calls, available_functions) if tool_result is not None: + if return_full_completion: + return { + "content": tool_result, + "choices": response.choices, + "usage": getattr(response, "usage", None), + "model": response.model, + "created": getattr(response, "created", None), + "id": getattr(response, "id", None), + "object": getattr(response, "object", "chat.completion"), + "system_fingerprint": getattr(response, "system_fingerprint", None), + } return tool_result - # --- 7) If tool call handling didn't return a result, emit completion event and return text response + # --- 7) If tool call handling didn't return a result, emit completion event and return response self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL) + if return_full_completion: + return { + "content": text_response, + "choices": response.choices, + "usage": getattr(response, "usage", None), + "model": response.model, + "created": getattr(response, "created", None), + "id": getattr(response, "id", None), + "object": getattr(response, "object", "chat.completion"), + "system_fingerprint": getattr(response, "system_fingerprint", None), + } return text_response def _handle_tool_call( @@ -889,7 +987,8 @@ class LLM(BaseLLM): tools: Optional[List[dict]] = None, callbacks: Optional[List[Any]] = None, available_functions: Optional[Dict[str, Any]] = None, - ) -> Union[str, Any]: + return_full_completion: Optional[bool] = None, + ) -> Union[str, Dict[str, Any]]: """High-level LLM call method. Args: @@ -903,10 +1002,11 @@ class LLM(BaseLLM): during and after the LLM call. available_functions: Optional dict mapping function names to callables that can be invoked by the LLM. + return_full_completion: Optional override for returning full completion object Returns: - Union[str, Any]: Either a text response from the LLM (str) or - the result of a tool function call (Any). + Union[str, Dict[str, Any]]: Either a text response from the LLM (str) or + the full completion object with generations and metadata. Raises: TypeError: If messages format is invalid @@ -944,17 +1044,20 @@ class LLM(BaseLLM): self.set_callbacks(callbacks) try: - # --- 6) Prepare parameters for the completion call + # --- 6) Determine if we should return full completion + should_return_full = return_full_completion if return_full_completion is not None else self.return_full_completion + + # --- 7) Prepare parameters for the completion call params = self._prepare_completion_params(messages, tools) - # --- 7) Make the completion call and handle response + # --- 8) Make the completion call and handle response if self.stream: return self._handle_streaming_response( - params, callbacks, available_functions + params, callbacks, available_functions, should_return_full ) else: return self._handle_non_streaming_response( - params, callbacks, available_functions + params, callbacks, available_functions, should_return_full ) except LLMContextLengthExceededException: diff --git a/src/crewai/tasks/task_output.py b/src/crewai/tasks/task_output.py index b0e8aecd4..cac2b1879 100644 --- a/src/crewai/tasks/task_output.py +++ b/src/crewai/tasks/task_output.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, model_validator @@ -26,6 +26,9 @@ class TaskOutput(BaseModel): output_format: OutputFormat = Field( description="Output format of the task", default=OutputFormat.RAW ) + completion_metadata: Optional[Dict[str, Any]] = Field( + description="Full completion metadata including generations and logprobs", default=None + ) @model_validator(mode="after") def set_summary(self): @@ -56,6 +59,40 @@ class TaskOutput(BaseModel): output_dict.update(self.pydantic.model_dump()) return output_dict + def get_generations(self) -> Optional[List[str]]: + """Get all generations from completion metadata.""" + if not self.completion_metadata or "choices" not in self.completion_metadata: + return None + + generations = [] + for choice in self.completion_metadata["choices"]: + if hasattr(choice, "message") and hasattr(choice.message, "content"): + generations.append(choice.message.content or "") + elif isinstance(choice, dict) and "message" in choice: + generations.append(choice["message"].get("content", "")) + + return generations if generations else None + + def get_logprobs(self) -> Optional[List[Dict[str, Any]]]: + """Get log probabilities from completion metadata.""" + if not self.completion_metadata or "choices" not in self.completion_metadata: + return None + + logprobs_list = [] + for choice in self.completion_metadata["choices"]: + if hasattr(choice, "logprobs") and choice.logprobs: + logprobs_list.append(choice.logprobs) + elif isinstance(choice, dict) and "logprobs" in choice: + logprobs_list.append(choice["logprobs"]) + + return logprobs_list if logprobs_list else None + + def get_usage_metrics(self) -> Optional[Dict[str, Any]]: + """Get token usage metrics from completion metadata.""" + if not self.completion_metadata: + return None + return self.completion_metadata.get("usage") + def __str__(self) -> str: if self.pydantic: return str(self.pydantic) diff --git a/src/crewai/utilities/agent_utils.py b/src/crewai/utilities/agent_utils.py index c3a38cc9a..93e242234 100644 --- a/src/crewai/utilities/agent_utils.py +++ b/src/crewai/utilities/agent_utils.py @@ -145,12 +145,14 @@ def get_llm_response( messages: List[Dict[str, str]], callbacks: List[Any], printer: Printer, -) -> str: + return_full_completion: bool = False, +) -> Union[str, Dict[str, Any]]: """Call the LLM and return the response, handling any invalid responses.""" try: answer = llm.call( messages, callbacks=callbacks, + return_full_completion=return_full_completion, ) except Exception as e: printer.print( @@ -158,29 +160,42 @@ def get_llm_response( color="red", ) raise e - if not answer: - printer.print( - content="Received None or empty response from LLM call.", - color="red", - ) - raise ValueError("Invalid response from LLM call - None or empty.") - - return answer + + if return_full_completion: + if not answer or (isinstance(answer, dict) and not answer.get("content")): + printer.print( + content="Received None or empty response from LLM call.", + color="red", + ) + raise ValueError("Invalid response from LLM call - None or empty.") + return answer + else: + if not answer: + printer.print( + content="Received None or empty response from LLM call.", + color="red", + ) + raise ValueError("Invalid response from LLM call - None or empty.") + return answer def process_llm_response( - answer: str, use_stop_words: bool + answer: Union[str, Dict[str, Any]], use_stop_words: bool ) -> Union[AgentAction, AgentFinish]: """Process the LLM response and format it into an AgentAction or AgentFinish.""" + text_answer = answer + if isinstance(answer, dict): + text_answer = answer.get("content", "") + if not use_stop_words: try: # Preliminary parsing to check for errors. - format_answer(answer) + format_answer(text_answer) except OutputParserException as e: if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error: - answer = answer.split("Observation:")[0].strip() + text_answer = text_answer.split("Observation:")[0].strip() - return format_answer(answer) + return format_answer(text_answer) def handle_agent_action_core( diff --git a/src/crewai/utilities/xml_parser.py b/src/crewai/utilities/xml_parser.py new file mode 100644 index 000000000..61e05f3e8 --- /dev/null +++ b/src/crewai/utilities/xml_parser.py @@ -0,0 +1,133 @@ +import re +from typing import Dict, List, Optional, Union + + +def extract_xml_content(text: str, tag: str) -> Optional[str]: + """ + Extract content from a specific XML tag. + + Args: + text: The text to search in + tag: The XML tag name to extract (without angle brackets) + + Returns: + The content inside the first occurrence of the tag, or None if not found + """ + pattern = rf'<{re.escape(tag)}>(.*?)' + match = re.search(pattern, text, re.DOTALL) + return match.group(1).strip() if match else None + + +def extract_all_xml_content(text: str, tag: str) -> List[str]: + """ + Extract content from all occurrences of a specific XML tag. + + Args: + text: The text to search in + tag: The XML tag name to extract (without angle brackets) + + Returns: + List of content strings from all occurrences of the tag + """ + pattern = rf'<{re.escape(tag)}>(.*?)' + matches = re.findall(pattern, text, re.DOTALL) + return [match.strip() for match in matches] + + +def extract_multiple_xml_tags(text: str, tags: List[str]) -> Dict[str, Optional[str]]: + """ + Extract content from multiple XML tags. + + Args: + text: The text to search in + tags: List of XML tag names to extract (without angle brackets) + + Returns: + Dictionary mapping tag names to their content (or None if not found) + """ + result = {} + for tag in tags: + result[tag] = extract_xml_content(text, tag) + return result + + +def extract_multiple_xml_tags_all(text: str, tags: List[str]) -> Dict[str, List[str]]: + """ + Extract content from all occurrences of multiple XML tags. + + Args: + text: The text to search in + tags: List of XML tag names to extract (without angle brackets) + + Returns: + Dictionary mapping tag names to lists of their content + """ + result = {} + for tag in tags: + result[tag] = extract_all_xml_content(text, tag) + return result + + +def extract_xml_with_attributes(text: str, tag: str) -> List[Dict[str, Union[str, Dict[str, str]]]]: + """ + Extract XML tags with their attributes and content. + + Args: + text: The text to search in + tag: The XML tag name to extract (without angle brackets) + + Returns: + List of dictionaries containing 'content' and 'attributes' for each occurrence + """ + pattern = rf'<{re.escape(tag)}([^>]*)>(.*?)' + matches = re.findall(pattern, text, re.DOTALL) + + result = [] + for attrs_str, content in matches: + attributes = {} + if attrs_str.strip(): + attr_pattern = r'(\w+)=["\']([^"\']*)["\']' + attributes = dict(re.findall(attr_pattern, attrs_str)) + + result.append({ + 'content': content.strip(), + 'attributes': attributes + }) + + return result + + +def remove_xml_tags(text: str, tags: List[str]) -> str: + """ + Remove specific XML tags and their content from text. + + Args: + text: The text to process + tags: List of XML tag names to remove (without angle brackets) + + Returns: + Text with the specified XML tags and their content removed + """ + result = text + for tag in tags: + pattern = rf'<{re.escape(tag)}[^>]*>.*?' + result = re.sub(pattern, '', result, flags=re.DOTALL) + return result.strip() + + +def strip_xml_tags_keep_content(text: str, tags: List[str]) -> str: + """ + Remove specific XML tags but keep their content. + + Args: + text: The text to process + tags: List of XML tag names to strip (without angle brackets) + + Returns: + Text with the specified XML tags removed but content preserved + """ + result = text + for tag in tags: + pattern = rf'<{re.escape(tag)}[^>]*>(.*?)' + result = re.sub(pattern, r'\1', result, flags=re.DOTALL) + return result.strip() diff --git a/tests/test_integration_llm_features.py b/tests/test_integration_llm_features.py new file mode 100644 index 000000000..251c34fb6 --- /dev/null +++ b/tests/test_integration_llm_features.py @@ -0,0 +1,158 @@ +import pytest +from unittest.mock import Mock, patch +from crewai import Agent, Task, Crew, LLM +from crewai.lite_agent import LiteAgent +from crewai.utilities.xml_parser import extract_xml_content + + +class TestIntegrationLLMFeatures: + """Integration tests for LLM features with agents and tasks.""" + + @patch('crewai.llm.litellm.completion') + def test_agent_with_multiple_generations(self, mock_completion): + """Test agent execution with multiple generations.""" + mock_response = Mock() + mock_response.choices = [ + Mock(message=Mock(content="Generation 1")), + Mock(message=Mock(content="Generation 2")), + Mock(message=Mock(content="Generation 3")), + ] + mock_response.usage = {"prompt_tokens": 20, "completion_tokens": 30} + mock_response.model = "gpt-3.5-turbo" + mock_response.created = 1234567890 + mock_response.id = "test-id" + mock_response.object = "chat.completion" + mock_response.system_fingerprint = "test-fingerprint" + mock_completion.return_value = mock_response + + llm = LLM(model="gpt-3.5-turbo", n=3, return_full_completion=True) + agent = Agent( + role="writer", + goal="write content", + backstory="You are a writer", + llm=llm, + return_completion_metadata=True, + ) + + task = Task( + description="Write a short story", + agent=agent, + expected_output="A short story", + ) + + with patch.object(agent, 'agent_executor') as mock_executor: + mock_executor.invoke.return_value = {"output": "Generation 1"} + + result = agent.execute_task(task) + assert result == "Generation 1" + + @patch('crewai.llm.litellm.completion') + def test_lite_agent_with_xml_extraction(self, mock_completion): + """Test LiteAgent with XML content extraction.""" + response_with_xml = """ + + I need to analyze this problem step by step. + First, I'll consider the requirements. + + + Based on my analysis, here's the solution: The answer is 42. + """ + + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content=response_with_xml))] + mock_response.usage = {"prompt_tokens": 15, "completion_tokens": 25} + mock_response.model = "gpt-3.5-turbo" + mock_response.created = 1234567890 + mock_response.id = "test-id" + mock_response.object = "chat.completion" + mock_response.system_fingerprint = "test-fingerprint" + mock_completion.return_value = mock_response + + lite_agent = LiteAgent( + role="analyst", + goal="analyze problems", + backstory="You are an analyst", + llm=LLM(model="gpt-3.5-turbo", return_full_completion=True), + ) + + with patch.object(lite_agent, '_invoke_loop') as mock_invoke: + mock_invoke.return_value = response_with_xml + + result = lite_agent.kickoff("Analyze this problem") + + thinking_content = extract_xml_content(result.raw, "thinking") + assert thinking_content is not None + assert "step by step" in thinking_content + assert "requirements" in thinking_content + + def test_xml_parser_with_complex_agent_output(self): + """Test XML parser with complex agent output containing multiple tags.""" + complex_output = """ + + This is a complex problem that requires careful analysis. + I need to break it down into steps. + + + + Step 1: Understand the requirements + Step 2: Analyze the constraints + Step 3: Develop a solution + + + + The best approach is to use a systematic methodology. + + + Final answer: Use the systematic approach outlined above. + """ + + thinking = extract_xml_content(complex_output, "thinking") + reasoning = extract_xml_content(complex_output, "reasoning") + conclusion = extract_xml_content(complex_output, "conclusion") + + assert thinking is not None + assert "complex problem" in thinking + assert reasoning is not None + assert "Step 1" in reasoning + assert "Step 2" in reasoning + assert "Step 3" in reasoning + assert conclusion is not None + assert "systematic methodology" in conclusion + + @patch('crewai.llm.litellm.completion') + def test_crew_with_llm_parameters(self, mock_completion): + """Test crew execution with LLM parameters.""" + mock_response = Mock() + mock_response.choices = [Mock(message=Mock(content="Test response"))] + mock_response.usage = {"prompt_tokens": 10, "completion_tokens": 5} + mock_response.model = "gpt-3.5-turbo" + mock_response.created = 1234567890 + mock_response.id = "test-id" + mock_response.object = "chat.completion" + mock_response.system_fingerprint = "test-fingerprint" + mock_completion.return_value = mock_response + + agent = Agent( + role="analyst", + goal="analyze data", + backstory="You are an analyst", + llm_n=2, + llm_logprobs=5, + return_completion_metadata=True, + ) + + task = Task( + description="Analyze the data", + agent=agent, + expected_output="Analysis results", + ) + + crew = Crew(agents=[agent], tasks=[task]) + + with patch.object(crew, 'kickoff') as mock_kickoff: + mock_output = Mock() + mock_output.tasks_output = [Mock(completion_metadata={"choices": mock_response.choices})] + mock_kickoff.return_value = mock_output + + result = crew.kickoff() + assert result is not None diff --git a/tests/test_llm_generations_logprobs.py b/tests/test_llm_generations_logprobs.py new file mode 100644 index 000000000..d6f52d19a --- /dev/null +++ b/tests/test_llm_generations_logprobs.py @@ -0,0 +1,227 @@ +import pytest +from unittest.mock import Mock, patch +from crewai import Agent, Task, LLM +from crewai.tasks.task_output import TaskOutput +from crewai.lite_agent import LiteAgent, LiteAgentOutput +from crewai.utilities.xml_parser import ( + extract_xml_content, + extract_all_xml_content, + extract_multiple_xml_tags, + extract_multiple_xml_tags_all, + extract_xml_with_attributes, + remove_xml_tags, + strip_xml_tags_keep_content, +) + + +class TestLLMGenerationsLogprobs: + """Test suite for LLM generations and logprobs functionality.""" + + def test_llm_with_n_parameter(self): + """Test that LLM accepts n parameter for multiple generations.""" + llm = LLM(model="gpt-3.5-turbo", n=3) + assert llm.n == 3 + + def test_llm_with_logprobs_parameter(self): + """Test that LLM accepts logprobs parameter.""" + llm = LLM(model="gpt-3.5-turbo", logprobs=5) + assert llm.logprobs == 5 + + def test_llm_with_return_full_completion(self): + """Test that LLM accepts return_full_completion parameter.""" + llm = LLM(model="gpt-3.5-turbo", return_full_completion=True) + assert llm.return_full_completion is True + + def test_agent_with_llm_parameters(self): + """Test that Agent accepts LLM generation parameters.""" + agent = Agent( + role="test", + goal="test", + backstory="test", + llm_n=3, + llm_logprobs=5, + llm_top_logprobs=3, + return_completion_metadata=True, + ) + assert agent.llm_n == 3 + assert agent.llm_logprobs == 5 + assert agent.llm_top_logprobs == 3 + assert agent.return_completion_metadata is True + + @patch('crewai.llm.litellm.completion') + def test_llm_call_returns_full_completion(self, mock_completion): + """Test that LLM.call can return full completion object.""" + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "Test response" + mock_response.usage = {"prompt_tokens": 10, "completion_tokens": 5} + mock_response.model = "gpt-3.5-turbo" + mock_response.created = 1234567890 + mock_response.id = "test-id" + mock_response.object = "chat.completion" + mock_response.system_fingerprint = "test-fingerprint" + mock_completion.return_value = mock_response + + llm = LLM(model="gpt-3.5-turbo", return_full_completion=True) + result = llm.call("Test message") + + assert isinstance(result, dict) + assert result["content"] == "Test response" + assert "choices" in result + assert "usage" in result + assert result["model"] == "gpt-3.5-turbo" + + def test_task_output_completion_metadata(self): + """Test TaskOutput with completion metadata.""" + mock_choices = [ + Mock(message=Mock(content="Generation 1")), + Mock(message=Mock(content="Generation 2")), + ] + mock_usage = {"prompt_tokens": 10, "completion_tokens": 15} + + completion_metadata = { + "choices": mock_choices, + "usage": mock_usage, + "model": "gpt-3.5-turbo", + } + + task_output = TaskOutput( + description="Test task", + raw="Generation 1", + agent="test-agent", + completion_metadata=completion_metadata, + ) + + generations = task_output.get_generations() + assert generations == ["Generation 1", "Generation 2"] + + usage = task_output.get_usage_metrics() + assert usage == mock_usage + + def test_lite_agent_output_completion_metadata(self): + """Test LiteAgentOutput with completion metadata.""" + mock_choices = [ + Mock(message=Mock(content="Generation 1")), + Mock(message=Mock(content="Generation 2")), + ] + mock_usage = {"prompt_tokens": 10, "completion_tokens": 15} + + completion_metadata = { + "choices": mock_choices, + "usage": mock_usage, + "model": "gpt-3.5-turbo", + } + + output = LiteAgentOutput( + raw="Generation 1", + agent_role="test-agent", + completion_metadata=completion_metadata, + ) + + generations = output.get_generations() + assert generations == ["Generation 1", "Generation 2"] + + usage = output.get_usage_metrics_from_completion() + assert usage == mock_usage + + +class TestXMLParser: + """Test suite for XML parsing functionality.""" + + def test_extract_xml_content_basic(self): + """Test basic XML content extraction.""" + text = "Some text This is my thought more text" + result = extract_xml_content(text, "thinking") + assert result == "This is my thought" + + def test_extract_xml_content_not_found(self): + """Test XML content extraction when tag not found.""" + text = "Some text without the tag" + result = extract_xml_content(text, "thinking") + assert result is None + + def test_extract_xml_content_multiline(self): + """Test XML content extraction with multiline content.""" + text = """Some text + + This is a multiline + thought process + + more text""" + result = extract_xml_content(text, "thinking") + assert "multiline" in result + assert "thought process" in result + + def test_extract_all_xml_content(self): + """Test extracting all occurrences of XML content.""" + text = """ + First thought + Some text + Second thought + """ + result = extract_all_xml_content(text, "thinking") + assert len(result) == 2 + assert result[0] == "First thought" + assert result[1] == "Second thought" + + def test_extract_multiple_xml_tags(self): + """Test extracting multiple different XML tags.""" + text = """ + My thoughts + My reasoning + My conclusion + """ + result = extract_multiple_xml_tags(text, ["thinking", "reasoning", "conclusion"]) + assert result["thinking"] == "My thoughts" + assert result["reasoning"] == "My reasoning" + assert result["conclusion"] == "My conclusion" + + def test_extract_multiple_xml_tags_all(self): + """Test extracting all occurrences of multiple XML tags.""" + text = """ + First thought + First reasoning + Second thought + """ + result = extract_multiple_xml_tags_all(text, ["thinking", "reasoning"]) + assert len(result["thinking"]) == 2 + assert len(result["reasoning"]) == 1 + assert result["thinking"][0] == "First thought" + assert result["thinking"][1] == "Second thought" + + def test_extract_xml_with_attributes(self): + """Test extracting XML with attributes.""" + text = 'Complex thought' + result = extract_xml_with_attributes(text, "thinking") + assert len(result) == 1 + assert result[0]["content"] == "Complex thought" + assert result[0]["attributes"]["type"] == "deep" + assert result[0]["attributes"]["level"] == "2" + + def test_remove_xml_tags(self): + """Test removing XML tags and their content.""" + text = "Keep this Remove this and this" + result = remove_xml_tags(text, ["thinking"]) + assert result == "Keep this and this" + + def test_strip_xml_tags_keep_content(self): + """Test stripping XML tags but keeping content.""" + text = "Keep this Keep this too and this" + result = strip_xml_tags_keep_content(text, ["thinking"]) + assert result == "Keep this Keep this too and this" + + def test_nested_xml_tags(self): + """Test handling of nested XML tags.""" + text = "Before nested content after" + result = extract_xml_content(text, "outer") + assert "Before" in result + assert "nested content" in result + assert "after" in result + + def test_xml_with_special_characters(self): + """Test XML parsing with special characters.""" + text = "Content with & < > \" ' characters" + result = extract_xml_content(text, "thinking") + assert "&" in result + assert "<" in result + assert ">" in result diff --git a/tests/test_xml_parser_examples.py b/tests/test_xml_parser_examples.py new file mode 100644 index 000000000..24194d978 --- /dev/null +++ b/tests/test_xml_parser_examples.py @@ -0,0 +1,162 @@ +import pytest +from crewai.utilities.xml_parser import ( + extract_xml_content, + extract_all_xml_content, + extract_multiple_xml_tags, + remove_xml_tags, + strip_xml_tags_keep_content, +) + + +class TestXMLParserExamples: + """Test XML parser with realistic agent output examples.""" + + def test_agent_thinking_extraction(self): + """Test extracting thinking content from agent output.""" + agent_output = """ + I need to solve this problem step by step. + + + Let me break this down: + 1. First, I need to understand the requirements + 2. Then, I'll analyze the constraints + 3. Finally, I'll propose a solution + + The key insight is that we need to balance efficiency with accuracy. + + + Based on my analysis, here's my recommendation: Use approach A. + """ + + thinking = extract_xml_content(agent_output, "thinking") + assert thinking is not None + assert "break this down" in thinking + assert "requirements" in thinking + assert "constraints" in thinking + assert "efficiency with accuracy" in thinking + + def test_multiple_reasoning_tags(self): + """Test extracting multiple reasoning sections.""" + agent_output = """ + + Initial analysis shows three possible approaches. + + + Let me explore each option: + + + Option A: Fast but less accurate + Option B: Slow but very accurate + Option C: Balanced approach + + + My final recommendation is Option C. + """ + + reasoning_sections = extract_all_xml_content(agent_output, "reasoning") + assert len(reasoning_sections) == 2 + assert "three possible approaches" in reasoning_sections[0] + assert "Option A" in reasoning_sections[1] + assert "Option B" in reasoning_sections[1] + assert "Option C" in reasoning_sections[1] + + def test_complex_agent_workflow(self): + """Test complex agent output with multiple tag types.""" + complex_output = """ + + This is a complex problem requiring systematic analysis. + I need to consider multiple factors. + + + + Factor 1: Performance requirements + Factor 2: Cost constraints + Factor 3: Time limitations + + + + Given the analysis above, I believe we should prioritize performance + while keeping costs reasonable. Time is less critical in this case. + + + + Recommend Solution X with performance optimizations. + + + Final answer: Implement Solution X with the following optimizations... + """ + + extracted = extract_multiple_xml_tags( + complex_output, + ["thinking", "analysis", "reasoning", "conclusion"] + ) + + assert extracted["thinking"] is not None + assert "systematic analysis" in extracted["thinking"] + + assert extracted["analysis"] is not None + assert "Factor 1" in extracted["analysis"] + assert "Factor 2" in extracted["analysis"] + assert "Factor 3" in extracted["analysis"] + + assert extracted["reasoning"] is not None + assert "prioritize performance" in extracted["reasoning"] + + assert extracted["conclusion"] is not None + assert "Solution X" in extracted["conclusion"] + + def test_clean_output_for_user(self): + """Test cleaning agent output for user presentation.""" + raw_output = """ + + Internal reasoning that user shouldn't see. + This contains implementation details. + + + + Debug information: variable X = 42 + + + Here's the answer to your question: The solution is to use method Y. + + + Remember to update the documentation later. + + + This approach will give you the best results. + """ + + clean_output = remove_xml_tags( + raw_output, + ["thinking", "debug", "internal_notes"] + ) + + assert "Internal reasoning" not in clean_output + assert "Debug information" not in clean_output + assert "update the documentation" not in clean_output + assert "Here's the answer" in clean_output + assert "method Y" in clean_output + assert "best results" in clean_output + + def test_preserve_structured_content(self): + """Test preserving structured content while removing tags.""" + structured_output = """ + + 1. Initialize the system + 2. Load the configuration + 3. Process the data + 4. Generate the report + + + Follow these steps to complete the task. + """ + + clean_output = strip_xml_tags_keep_content(structured_output, ["steps"]) + + assert "" not in clean_output + assert "" not in clean_output + assert "1. Initialize" in clean_output + assert "2. Load" in clean_output + assert "3. Process" in clean_output + assert "4. Generate" in clean_output + assert "Follow these steps" in clean_output