From d19d7b01ec6f364233abb07d24079be48e648c7c Mon Sep 17 00:00:00 2001
From: Daniel Barreto
Date: Wed, 29 Jan 2025 12:11:48 -0300
Subject: [PATCH 1/6] docs: add a "Human Input" row to the Task Attributes
 table (#1999)

---
 docs/concepts/tasks.mdx | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/docs/concepts/tasks.mdx b/docs/concepts/tasks.mdx
index 6ffd95e19..de7378879 100644
--- a/docs/concepts/tasks.mdx
+++ b/docs/concepts/tasks.mdx
@@ -33,11 +33,12 @@ crew = Crew(
 | :------------------------------- | :---------------- | :---------------------------- | :------------------------------------------------------------------------------------------------------------------- |
 | **Description**                  | `description`     | `str`                         | A clear, concise statement of what the task entails.                                          |
 | **Expected Output**              | `expected_output` | `str`                         | A detailed description of what the task's completion looks like.                              |
-| **Name** _(optional)_            | `name`            | `Optional[str]`               | A name identifier for the task.                                                               |
-| **Agent** _(optional)_           | `agent`           | `Optional[BaseAgent]`         | The agent responsible for executing the task.                                                 |
-| **Tools** _(optional)_           | `tools`           | `List[BaseTool]`              | The tools/resources the agent is limited to use for this task.                                |
+| **Name** _(optional)_            | `name`            | `Optional[str]`               | A name identifier for the task.                                                               |
+| **Agent** _(optional)_           | `agent`           | `Optional[BaseAgent]`         | The agent responsible for executing the task.                                                 |
+| **Tools** _(optional)_           | `tools`           | `List[BaseTool]`              | The tools/resources the agent is limited to use for this task.                                |
 | **Context** _(optional)_         | `context`         | `Optional[List["Task"]]`      | Other tasks whose outputs will be used as context for this task.                              |
 | **Async Execution** _(optional)_ | `async_execution` | `Optional[bool]`              | Whether the task should be executed asynchronously. Defaults to False.                        |
+| **Human Input** _(optional)_     | `human_input`     | `Optional[bool]`              | Whether the task should have a human review the final answer of the agent. Defaults to False. |
 | **Config** _(optional)_          | `config`          | `Optional[Dict[str, Any]]`    | Task-specific configuration parameters.                                                       |
 | **Output File** _(optional)_     | `output_file`     | `Optional[str]`               | File path for storing the task output.                                                        |
 | **Output JSON** _(optional)_     | `output_json`     | `Optional[Type[BaseModel]]`   | A Pydantic model to structure the JSON output.                                                |
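The new row documents the `human_input` flag end to end. As a quick illustration, here is a minimal sketch of a task that pauses for review (the agent and task details are placeholders, not part of the patch):

```python
from crewai import Agent, Crew, Task

reviewer = Agent(
    role="Research Analyst",  # placeholder agent for illustration
    goal="Summarize findings accurately",
    backstory="A careful analyst who values precise summaries.",
)

task = Task(
    description="Summarize the key findings of the report.",
    expected_output="A five-bullet summary.",
    agent=reviewer,
    human_input=True,  # pause so a human can review the final answer
)

crew = Crew(agents=[reviewer], tasks=[task])
# crew.kickoff() prompts for console feedback before the task completes
```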
From 2709a9205a042e2baabd7d2f97f40365337b8c30 Mon Sep 17 00:00:00 2001
From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
Date: Wed, 29 Jan 2025 10:24:50 -0800
Subject: [PATCH 2/6] =?UTF-8?q?fixes=20interpolation=20issues=20when=20inp?=
 =?UTF-8?q?uts=20are=20type=20dict,list=20specificall=E2=80=A6=20(#1992)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fixes interpolation issues when inputs are type dict,list specifically
  when defined on expected_output

* improvements with type hints, doc fixes and rm print statements

* more tests

* test passing

---------

Co-authored-by: Brandon Hancock
---
 src/crewai/task.py |  48 ++++---
 tests/task_test.py | 317 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 350 insertions(+), 15 deletions(-)

diff --git a/src/crewai/task.py b/src/crewai/task.py
index 030bce779..cbf651f9b 100644
--- a/src/crewai/task.py
+++ b/src/crewai/task.py
@@ -431,7 +431,9 @@ class Task(BaseModel):
         content = (
             json_output
             if json_output
-            else pydantic_output.model_dump_json() if pydantic_output else result
+            else pydantic_output.model_dump_json()
+            if pydantic_output
+            else result
         )
 
         self._save_file(content)
@@ -452,7 +454,7 @@ class Task(BaseModel):
         return "\n".join(tasks_slices)
 
     def interpolate_inputs_and_add_conversation_history(
-        self, inputs: Dict[str, Union[str, int, float]]
+        self, inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]]
     ) -> None:
         """Interpolate inputs into the task description, expected output, and output file path.
            Add conversation history if present.
@@ -524,7 +526,9 @@ class Task(BaseModel):
         )
 
     def interpolate_only(
-        self, input_string: Optional[str], inputs: Dict[str, Union[str, int, float]]
+        self,
+        input_string: Optional[str],
+        inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]],
     ) -> str:
         """Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched.

@@ -532,17 +536,39 @@
             input_string: The string containing template variables to interpolate.
                 Can be None or empty, in which case an empty string is returned.
             inputs: Dictionary mapping template variables to their values.
-                Supported value types are strings, integers, and floats.
-                If input_string is empty or has no placeholders, inputs can be empty.
+                Supported value types are strings, integers, floats, and dicts/lists
+                containing only these types and other nested dicts/lists.
 
         Returns:
             The interpolated string with all template variables replaced with their values.
             Empty string if input_string is None or empty.
 
         Raises:
-            ValueError: If a required template variable is missing from inputs.
-            KeyError: If a template variable is not found in the inputs dictionary.
+            ValueError: If a value contains unsupported types.
         """
+
+        # Validation function for recursive type checking
+        def validate_type(value: Any) -> None:
+            if value is None:
+                return
+            if isinstance(value, (str, int, float, bool)):
+                return
+            if isinstance(value, (dict, list)):
+                for item in value.values() if isinstance(value, dict) else value:
+                    validate_type(item)
+                return
+            raise ValueError(
+                f"Unsupported type {type(value).__name__} in inputs. "
+                "Only str, int, float, bool, dict, and list are allowed."
+            )
+
+        # Validate all input values
+        for key, value in inputs.items():
+            try:
+                validate_type(value)
+            except ValueError as e:
+                raise ValueError(f"Invalid value for key '{key}': {str(e)}") from e
+
         if input_string is None or not input_string:
             return ""
         if "{" not in input_string and "}" not in input_string:
@@ -551,15 +577,7 @@
             raise ValueError(
                 "Inputs dictionary cannot be empty when interpolating variables"
             )
-        try:
-            # Validate input types
-            for key, value in inputs.items():
-                if not isinstance(value, (str, int, float)):
-                    raise ValueError(
-                        f"Value for key '{key}' must be a string, integer, or float, got {type(value).__name__}"
-                    )
-
         escaped_string = input_string.replace("{", "{{").replace("}", "}}")
 
         for key in inputs.keys():
diff --git a/tests/task_test.py b/tests/task_test.py
index 59e58dcca..5ffaf2534 100644
--- a/tests/task_test.py
+++ b/tests/task_test.py
@@ -779,6 +779,43 @@ def test_interpolate_only():
     assert result == no_placeholders
 
 
+def test_interpolate_only_with_dict_inside_expected_output():
+    """Test interpolate_only with a dict value, ensuring the JSON structure is preserved."""
+    task = Task(
+        description="Unused in this test",
+        expected_output="Unused in this test: {questions}",
+    )
+
+    json_string = '{"questions": {"main_question": "What is the user\'s name?", "secondary_question": "What is the user\'s age?"}}'
+    result = task.interpolate_only(
+        input_string=json_string,
+        inputs={
+            "questions": {
+                "main_question": "What is the user's name?",
+                "secondary_question": "What is the user's age?",
+            }
+        },
+    )
+    assert '"main_question": "What is the user\'s name?"' in result
+    assert '"secondary_question": "What is the user\'s age?"' in result
+    assert result == json_string
+
+    normal_string = "Hello {name}, welcome to {place}!"
+    result = task.interpolate_only(
+        input_string=normal_string, inputs={"name": "John", "place": "CrewAI"}
+    )
+    assert result == "Hello John, welcome to CrewAI!"
+
+    result = task.interpolate_only(input_string="", inputs={"unused": "value"})
+    assert result == ""
+
+    no_placeholders = "Hello, this is a test"
+    result = task.interpolate_only(
+        input_string=no_placeholders, inputs={"unused": "value"}
+    )
+    assert result == no_placeholders
+
+
 def test_task_output_str_with_pydantic():
     from crewai.tasks.output_format import OutputFormat
 
@@ -966,3 +1003,283 @@ def test_task_execution_times():
     assert task.start_time is not None
     assert task.end_time is not None
     assert task.execution_duration == (task.end_time - task.start_time).total_seconds()
+
+
+def test_interpolate_with_list_of_strings():
+    task = Task(
+        description="Test list interpolation",
+        expected_output="List: {items}",
+    )
+
+    # Test simple list of strings
+    input_str = "Available items: {items}"
+    inputs = {"items": ["apple", "banana", "cherry"]}
+    result = task.interpolate_only(input_str, inputs)
+    assert result == f"Available items: {inputs['items']}"
+
+    # Test empty list
+    empty_list_input = {"items": []}
+    result = task.interpolate_only(input_str, empty_list_input)
+    assert result == "Available items: []"
+
+
+def test_interpolate_with_list_of_dicts():
+    task = Task(
+        description="Test list of dicts interpolation",
+        expected_output="People: {people}",
+    )
+
+    input_data = {
+        "people": [
+            {"name": "Alice", "age": 30, "skills": ["Python", "AI"]},
+            {"name": "Bob", "age": 25, "skills": ["Java", "Cloud"]},
+        ]
+    }
+    result = task.interpolate_only("{people}", input_data)
+
+    parsed_result = eval(result)
+    assert isinstance(parsed_result, list)
+    assert len(parsed_result) == 2
+    assert parsed_result[0]["name"] == "Alice"
+    assert parsed_result[0]["age"] == 30
+    assert parsed_result[0]["skills"] == ["Python", "AI"]
+    assert parsed_result[1]["name"] == "Bob"
+    assert parsed_result[1]["age"] == 25
+    assert parsed_result[1]["skills"] == ["Java", "Cloud"]
+
+
+def test_interpolate_with_nested_structures():
+    task = Task(
+        description="Test nested structures",
+        expected_output="Company: {company}",
+    )
+
+    input_data = {
+        "company": {
+            "name": "TechCorp",
+            "departments": [
+                {
+                    "name": "Engineering",
+                    "employees": 50,
+                    "tools": ["Git", "Docker", "Kubernetes"],
+                },
+                {"name": "Sales", "employees": 20, "regions": {"north": 5, "south": 3}},
+            ],
+        }
+    }
+    result = task.interpolate_only("{company}", input_data)
+    parsed = eval(result)
+
+    assert parsed["name"] == "TechCorp"
+    assert len(parsed["departments"]) == 2
+    assert parsed["departments"][0]["tools"] == ["Git", "Docker", "Kubernetes"]
+    assert parsed["departments"][1]["regions"]["north"] == 5
+
+
+def test_interpolate_with_special_characters():
+    task = Task(
+        description="Test special characters in dicts",
+        expected_output="Data: {special_data}",
+    )
+
+    input_data = {
+        "special_data": {
+            "quotes": """This has "double" and 'single' quotes""",
+            "unicode": "文字化けテスト",
+            "symbols": "!@#$%^&*()",
+            "empty": "",
+        }
+    }
+    result = task.interpolate_only("{special_data}", input_data)
+    parsed = eval(result)
+
+    assert parsed["quotes"] == """This has "double" and 'single' quotes"""
+    assert parsed["unicode"] == "文字化けテスト"
+    assert parsed["symbols"] == "!@#$%^&*()"
+    assert parsed["empty"] == ""
+
+
+def test_interpolate_mixed_types():
+    task = Task(
+        description="Test mixed type interpolation",
+        expected_output="Mixed: {data}",
+    )
+
+    input_data = {
+        "data": {
+            "name": "Test Dataset",
+            "samples": 1000,
+            "features": ["age", "income", "location"],
+            "metadata": {
+                "source": "public",
+                "validated": True,
+                "tags": ["demo", "test", "temp"],
+            },
+        }
+    }
+    result = task.interpolate_only("{data}", input_data)
+    parsed = eval(result)
+
+    assert parsed["name"] == "Test Dataset"
+    assert parsed["samples"] == 1000
+    assert parsed["metadata"]["tags"] == ["demo", "test", "temp"]
+
+
+def test_interpolate_complex_combination():
+    task = Task(
+        description="Test complex combination",
+        expected_output="Report: {report}",
+    )
+
+    input_data = {
+        "report": [
+            {
+                "month": "January",
+                "metrics": {"sales": 15000, "expenses": 8000, "profit": 7000},
+                "top_products": ["Product A", "Product B"],
+            },
+            {
+                "month": "February",
+                "metrics": {"sales": 18000, "expenses": 8500, "profit": 9500},
+                "top_products": ["Product C", "Product D"],
+            },
+        ]
+    }
+    result = task.interpolate_only("{report}", input_data)
+    parsed = eval(result)
+
+    assert len(parsed) == 2
+    assert parsed[0]["month"] == "January"
+    assert parsed[1]["metrics"]["profit"] == 9500
+    assert "Product D" in parsed[1]["top_products"]
+
+
+def test_interpolate_invalid_type_validation():
+    task = Task(
+        description="Test invalid type validation",
+        expected_output="Should never reach here",
+    )
+
+    # Test with invalid top-level type
+    with pytest.raises(ValueError) as excinfo:
+        task.interpolate_only("{data}", {"data": set()})  # type: ignore we are purposely testing this failure
+
+    assert "Unsupported type set" in str(excinfo.value)
+
+    # Test with invalid nested type
+    invalid_nested = {
+        "profile": {
+            "name": "John",
+            "age": 30,
+            "tags": {"a", "b", "c"},  # Set is invalid
+        }
+    }
+    with pytest.raises(ValueError) as excinfo:
+        task.interpolate_only("{data}", {"data": invalid_nested})
+    assert "Unsupported type set" in str(excinfo.value)
+
+
+def test_interpolate_custom_object_validation():
+    task = Task(
+        description="Test custom object rejection",
+        expected_output="Should never reach here",
+    )
+
+    class CustomObject:
+        def __init__(self, value):
+            self.value = value
+
+        def __str__(self):
+            return str(self.value)
+
+    # Test with custom object at top level
+    with pytest.raises(ValueError) as excinfo:
+        task.interpolate_only("{obj}", {"obj": CustomObject(5)})  # type: ignore we are purposely testing this failure
+    assert "Unsupported type CustomObject" in str(excinfo.value)
+
+    # Test with nested custom object in dictionary
+    with pytest.raises(ValueError) as excinfo:
+        task.interpolate_only(
+            "{data}", {"data": {"valid": 1, "invalid": CustomObject(5)}}
+        )
+    assert "Unsupported type CustomObject" in str(excinfo.value)
+
+    # Test with nested custom object in list
+    with pytest.raises(ValueError) as excinfo:
+        task.interpolate_only("{data}", {"data": [1, "valid", CustomObject(5)]})
+    assert "Unsupported type CustomObject" in str(excinfo.value)
+
+    # Test with deeply nested custom object
+    with pytest.raises(ValueError) as excinfo:
+        task.interpolate_only(
+            "{data}", {"data": {"level1": {"level2": [{"level3": CustomObject(5)}]}}}
+        )
+    assert "Unsupported type CustomObject" in str(excinfo.value)
+
+
+def test_interpolate_valid_complex_types():
+    task = Task(
+        description="Test valid complex types",
+        expected_output="Validation should pass",
+    )
+
+    # Valid complex structure
+    valid_data = {
+        "name": "Valid Dataset",
+        "stats": {
+            "count": 1000,
+            "distribution": [0.2, 0.3, 0.5],
+            "features": ["age", "income"],
+            "nested": {"deep": [1, 2, 3], "deeper": {"a": 1, "b": 2.5}},
+        },
+    }
+
+    # Should not raise any errors
+    result = task.interpolate_only("{data}", {"data": valid_data})
+    parsed = eval(result)
+    assert parsed["name"] == "Valid Dataset"
+    assert parsed["stats"]["nested"]["deeper"]["b"] == 2.5
+
+
+def test_interpolate_edge_cases():
+    task = Task(
+        description="Test edge cases",
+        expected_output="Edge case handling",
+    )
+
+    # Test empty dict and list
+    assert task.interpolate_only("{}", {"data": {}}) == "{}"
+    assert task.interpolate_only("[]", {"data": []}) == "[]"
+
+    # Test numeric types
+    assert task.interpolate_only("{num}", {"num": 42}) == "42"
+    assert task.interpolate_only("{num}", {"num": 3.14}) == "3.14"
+
+    # Test boolean values (valid JSON types)
+    assert task.interpolate_only("{flag}", {"flag": True}) == "True"
+    assert task.interpolate_only("{flag}", {"flag": False}) == "False"
+
+
+def test_interpolate_valid_types():
+    task = Task(
+        description="Test valid types including null and boolean",
+        expected_output="Should pass validation",
+    )
+
+    # Test with boolean and null values (valid JSON types)
+    valid_data = {
+        "name": "Test",
+        "active": True,
+        "deleted": False,
+        "optional": None,
+        "nested": {"flag": True, "empty": None},
+    }
+
+    result = task.interpolate_only("{data}", {"data": valid_data})
+    parsed = eval(result)
+
+    assert parsed["active"] is True
+    assert parsed["deleted"] is False
+    assert parsed["optional"] is None
+    assert parsed["nested"]["flag"] is True
+    assert parsed["nested"]["empty"] is None
" - "Respond with 'looks good' or a similar phrase when you're satisfied.\n" + "## TRAINING MODE: Provide feedback to improve the agent's performance.\n" + "This will be used to train better versions of the agent.\n" + "Please provide detailed feedback about the result quality and reasoning process.\n" "=====\n" - ), - color="bold_yellow", - ) + ) + # Regular human-in-the-loop prompt (multiple iterations) + else: + prompt = ( + "\n\n=====\n" + "## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n" + "Respond with 'looks good' to accept or provide specific improvement requests.\n" + "You can provide multiple rounds of feedback until satisfied.\n" + "=====\n" + ) + + self._printer.print(content=prompt, color="bold_yellow") return input() diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index b9797193c..b144872b1 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -100,6 +100,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): try: formatted_answer = self._invoke_loop() + except AssertionError: + self._printer.print( + content="Agent failed to reach a final answer. This is likely a bug - please report it.", + color="red", + ) + raise except Exception as e: if e.__class__.__module__.startswith("litellm"): # Do not retry on litellm errors @@ -115,7 +121,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self._create_long_term_memory(formatted_answer) return {"output": formatted_answer.output} - def _invoke_loop(self): + def _invoke_loop(self) -> AgentFinish: """ Main loop to invoke the agent's thought process until it reaches a conclusion or the maximum number of iterations is reached. @@ -161,6 +167,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): finally: self.iterations += 1 + # During the invoke loop, formatted_answer alternates between AgentAction + # (when the agent is using tools) and eventually becomes AgentFinish + # (when the agent reaches a final answer). This assertion confirms we've + # reached a final answer and helps type checking understand this transition. + assert isinstance(formatted_answer, AgentFinish) self._show_logs(formatted_answer) return formatted_answer @@ -292,8 +303,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self._printer.print( content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m" ) + description = ( + getattr(self.task, "description") if self.task else "Not Found" + ) self._printer.print( - content=f"\033[95m## Task:\033[00m \033[92m{self.task.description}\033[00m" + content=f"\033[95m## Task:\033[00m \033[92m{description}\033[00m" ) def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]): @@ -418,58 +432,50 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ) def _handle_crew_training_output( - self, result: AgentFinish, human_feedback: str | None = None + self, result: AgentFinish, human_feedback: Optional[str] = None ) -> None: - """Function to handle the process of the training data.""" + """Handle the process of saving training data.""" agent_id = str(self.agent.id) # type: ignore + train_iteration = ( + getattr(self.crew, "_train_iteration", None) if self.crew else None + ) + + if train_iteration is None or not isinstance(train_iteration, int): + self._printer.print( + content="Invalid or missing train iteration. 
Cannot save training data.", + color="red", + ) + return - # Load training data training_handler = CrewTrainingHandler(TRAINING_DATA_FILE) - training_data = training_handler.load() + training_data = training_handler.load() or {} - # Check if training data exists, human input is not requested, and self.crew is valid - if training_data and not self.ask_for_human_input: - if self.crew is not None and hasattr(self.crew, "_train_iteration"): - train_iteration = self.crew._train_iteration - if agent_id in training_data and isinstance(train_iteration, int): - training_data[agent_id][train_iteration][ - "improved_output" - ] = result.output - training_handler.save(training_data) - else: - self._printer.print( - content="Invalid train iteration type or agent_id not in training data.", - color="red", - ) - else: - self._printer.print( - content="Crew is None or does not have _train_iteration attribute.", - color="red", - ) + # Initialize or retrieve agent's training data + agent_training_data = training_data.get(agent_id, {}) - if self.ask_for_human_input and human_feedback is not None: - training_data = { + if human_feedback is not None: + # Save initial output and human feedback + agent_training_data[train_iteration] = { "initial_output": result.output, "human_feedback": human_feedback, - "agent": agent_id, - "agent_role": self.agent.role, # type: ignore } - if self.crew is not None and hasattr(self.crew, "_train_iteration"): - train_iteration = self.crew._train_iteration - if isinstance(train_iteration, int): - CrewTrainingHandler(TRAINING_DATA_FILE).append( - train_iteration, agent_id, training_data - ) - else: - self._printer.print( - content="Invalid train iteration type. Expected int.", - color="red", - ) + else: + # Save improved output + if train_iteration in agent_training_data: + agent_training_data[train_iteration]["improved_output"] = result.output else: self._printer.print( - content="Crew is None or does not have _train_iteration attribute.", + content=( + f"No existing training data for agent {agent_id} and iteration " + f"{train_iteration}. Cannot save improved output." + ), color="red", ) + return + + # Update the training data and save + training_data[agent_id] = agent_training_data + training_handler.save(training_data) def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str: prompt = prompt.replace("{input}", inputs["input"]) @@ -485,82 +491,103 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): return {"role": role, "content": prompt} def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish: - """ - Handles the human feedback loop, allowing the user to provide feedback - on the agent's output and determining if additional iterations are needed. + """Handle human feedback with different flows for training vs regular use. - Parameters: - formatted_answer (AgentFinish): The initial output from the agent. + Args: + formatted_answer: The initial AgentFinish result to get feedback on Returns: - AgentFinish: The final output after incorporating human feedback. 
+ AgentFinish: The final answer after processing feedback """ + human_feedback = self._ask_human_input(formatted_answer.output) + + if self._is_training_mode(): + return self._handle_training_feedback(formatted_answer, human_feedback) + + return self._handle_regular_feedback(formatted_answer, human_feedback) + + def _is_training_mode(self) -> bool: + """Check if crew is in training mode.""" + return bool(self.crew and self.crew._train) + + def _handle_training_feedback( + self, initial_answer: AgentFinish, feedback: str + ) -> AgentFinish: + """Process feedback for training scenarios with single iteration.""" + self._printer.print( + content="\nProcessing training feedback.\n", + color="yellow", + ) + self._handle_crew_training_output(initial_answer, feedback) + self.messages.append(self._format_msg(f"Feedback: {feedback}")) + improved_answer = self._invoke_loop() + self._handle_crew_training_output(improved_answer) + self.ask_for_human_input = False + return improved_answer + + def _handle_regular_feedback( + self, current_answer: AgentFinish, initial_feedback: str + ) -> AgentFinish: + """Process feedback for regular use with potential multiple iterations.""" + feedback = initial_feedback + answer = current_answer + while self.ask_for_human_input: - human_feedback = self._ask_human_input(formatted_answer.output) + response = self._get_llm_feedback_response(feedback) - if self.crew and self.crew._train: - self._handle_crew_training_output(formatted_answer, human_feedback) - - # Make an LLM call to verify if additional changes are requested based on human feedback - additional_changes_prompt = self._i18n.slice( - "human_feedback_classification" - ).format(feedback=human_feedback) - - retry_count = 0 - llm_call_successful = False - additional_changes_response = None - - while retry_count < MAX_LLM_RETRY and not llm_call_successful: - try: - additional_changes_response = ( - self.llm.call( - [ - self._format_msg( - additional_changes_prompt, role="system" - ) - ], - callbacks=self.callbacks, - ) - .strip() - .lower() - ) - llm_call_successful = True - except Exception as e: - retry_count += 1 - - self._printer.print( - content=f"Error during LLM call to classify human feedback: {e}. Retrying... ({retry_count}/{MAX_LLM_RETRY})", - color="red", - ) - - if not llm_call_successful: - self._printer.print( - content="Error processing feedback after multiple attempts.", - color="red", - ) + if not self._feedback_requires_changes(response): self.ask_for_human_input = False - break - - if additional_changes_response == "false": - self.ask_for_human_input = False - elif additional_changes_response == "true": - self.ask_for_human_input = True - # Add human feedback to messages - self.messages.append(self._format_msg(f"Feedback: {human_feedback}")) - # Invoke the loop again with updated messages - formatted_answer = self._invoke_loop() - - if self.crew and self.crew._train: - self._handle_crew_training_output(formatted_answer) else: - # Unexpected response - self._printer.print( - content=f"Unexpected response from LLM: '{additional_changes_response}'. 
Assuming no additional changes requested.", - color="red", - ) - self.ask_for_human_input = False + answer = self._process_feedback_iteration(feedback) + feedback = self._ask_human_input(answer.output) - return formatted_answer + return answer + + def _get_llm_feedback_response(self, feedback: str) -> Optional[str]: + """Get LLM classification of whether feedback requires changes.""" + prompt = self._i18n.slice("human_feedback_classification").format( + feedback=feedback + ) + message = self._format_msg(prompt, role="system") + + for retry in range(MAX_LLM_RETRY): + try: + response = self.llm.call([message], callbacks=self.callbacks) + return response.strip().lower() if response else None + except Exception as error: + self._log_feedback_error(retry, error) + + self._log_max_retries_exceeded() + return None + + def _feedback_requires_changes(self, response: Optional[str]) -> bool: + """Determine if feedback response indicates need for changes.""" + return response == "true" if response else False + + def _process_feedback_iteration(self, feedback: str) -> AgentFinish: + """Process a single feedback iteration.""" + self.messages.append(self._format_msg(f"Feedback: {feedback}")) + return self._invoke_loop() + + def _log_feedback_error(self, retry_count: int, error: Exception) -> None: + """Log feedback processing errors.""" + self._printer.print( + content=( + f"Error processing feedback: {error}. " + f"Retrying... ({retry_count + 1}/{MAX_LLM_RETRY})" + ), + color="red", + ) + + def _log_max_retries_exceeded(self) -> None: + """Log when max retries for feedback processing are exceeded.""" + self._printer.print( + content=( + f"Failed to process feedback after {MAX_LLM_RETRY} attempts. " + "Ending feedback loop." + ), + color="red", + ) def _handle_max_iterations_exceeded(self, formatted_answer): """ diff --git a/src/crewai/crew.py b/src/crewai/crew.py index b44667042..93987f3b8 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -494,21 +494,26 @@ class Crew(BaseModel): train_crew = self.copy() train_crew._setup_for_training(filename) - for n_iteration in range(n_iterations): - train_crew._train_iteration = n_iteration - train_crew.kickoff(inputs=inputs) + try: + for n_iteration in range(n_iterations): + train_crew._train_iteration = n_iteration + train_crew.kickoff(inputs=inputs) - training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load() + training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load() - for agent in train_crew.agents: - if training_data.get(str(agent.id)): - result = TaskEvaluator(agent).evaluate_training_data( - training_data=training_data, agent_id=str(agent.id) - ) - - CrewTrainingHandler(filename).save_trained_data( - agent_id=str(agent.role), trained_data=result.model_dump() - ) + for agent in train_crew.agents: + if training_data.get(str(agent.id)): + result = TaskEvaluator(agent).evaluate_training_data( + training_data=training_data, agent_id=str(agent.id) + ) + CrewTrainingHandler(filename).save_trained_data( + agent_id=str(agent.role), trained_data=result.model_dump() + ) + except Exception as e: + self._logger.log("error", f"Training failed: {e}", color="red") + CrewTrainingHandler(TRAINING_DATA_FILE).clear() + CrewTrainingHandler(filename).clear() + raise def kickoff( self, diff --git a/src/crewai/utilities/evaluators/task_evaluator.py b/src/crewai/utilities/evaluators/task_evaluator.py index acfdceed6..294629274 100644 --- a/src/crewai/utilities/evaluators/task_evaluator.py +++ b/src/crewai/utilities/evaluators/task_evaluator.py @@ -92,13 +92,34 @@ 
class TaskEvaluator: """ output_training_data = training_data[agent_id] - final_aggregated_data = "" - for _, data in output_training_data.items(): + + for iteration, data in output_training_data.items(): + improved_output = data.get("improved_output") + initial_output = data.get("initial_output") + human_feedback = data.get("human_feedback") + + if not all([improved_output, initial_output, human_feedback]): + missing_fields = [ + field + for field in ["improved_output", "initial_output", "human_feedback"] + if not data.get(field) + ] + error_msg = ( + f"Critical training data error: Missing fields ({', '.join(missing_fields)}) " + f"for agent {agent_id} in iteration {iteration}.\n" + "This indicates a broken training process. " + "Cannot proceed with evaluation.\n" + "Please check your training implementation." + ) + raise ValueError(error_msg) + final_aggregated_data += ( - f"Initial Output:\n{data.get('initial_output', '')}\n\n" - f"Human Feedback:\n{data.get('human_feedback', '')}\n\n" - f"Improved Output:\n{data.get('improved_output', '')}\n\n" + f"Iteration: {iteration}\n" + f"Initial Output:\n{initial_output}\n\n" + f"Human Feedback:\n{human_feedback}\n\n" + f"Improved Output:\n{improved_output}\n\n" + "------------------------------------------------\n\n" ) evaluation_query = ( diff --git a/src/crewai/utilities/training_handler.py b/src/crewai/utilities/training_handler.py index 5cadde619..b6b3c38b6 100644 --- a/src/crewai/utilities/training_handler.py +++ b/src/crewai/utilities/training_handler.py @@ -1,3 +1,5 @@ +import os + from crewai.utilities.file_handler import PickleHandler @@ -29,3 +31,10 @@ class CrewTrainingHandler(PickleHandler): data[agent_id] = {train_iteration: new_data} self.save(data) + + def clear(self) -> None: + """Clear the training data by removing the file or resetting its contents.""" + if os.path.exists(self.file_path): + with open(self.file_path, "wb") as file: + # Overwrite with an empty dictionary + self.save({}) diff --git a/tests/utilities/evaluators/test_task_evaluator.py b/tests/utilities/evaluators/test_task_evaluator.py index 8a0be027a..e4de1db62 100644 --- a/tests/utilities/evaluators/test_task_evaluator.py +++ b/tests/utilities/evaluators/test_task_evaluator.py @@ -48,9 +48,9 @@ def test_evaluate_training_data(converter_mock): mock.call( llm=original_agent.llm, text="Assess the quality of the training data based on the llm output, human feedback , and llm " - "output improved result.\n\nInitial Output:\nInitial output 1\n\nHuman Feedback:\nHuman feedback " - "1\n\nImproved Output:\nImproved output 1\n\nInitial Output:\nInitial output 2\n\nHuman " - "Feedback:\nHuman feedback 2\n\nImproved Output:\nImproved output 2\n\nPlease provide:\n- Provide " + "output improved result.\n\nIteration: data1\nInitial Output:\nInitial output 1\n\nHuman Feedback:\nHuman feedback " + "1\n\nImproved Output:\nImproved output 1\n\n------------------------------------------------\n\nIteration: data2\nInitial Output:\nInitial output 2\n\nHuman " + "Feedback:\nHuman feedback 2\n\nImproved Output:\nImproved output 2\n\n------------------------------------------------\n\nPlease provide:\n- Provide " "a list of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's " "performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific " "action items for future tasks. 
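The training flow these fixes repair is driven from `Crew.train()`, shown in the crew.py hunk above. A minimal sketch of a training run, assuming the public signature matches the parameters used there (the agent, task, and inputs are placeholders):

```python
from crewai import Agent, Crew, Task

analyst = Agent(
    role="Analyst",  # placeholder agent for illustration
    goal="Write concise summaries",
    backstory="A careful technical writer.",
)
task = Task(
    description="Summarize: {topic}",
    expected_output="A three-sentence summary.",
    agent=analyst,
)
crew = Crew(agents=[analyst], tasks=[task])

# Each iteration kicks off the crew, collects human feedback on the final
# answer, re-runs with that feedback, and stores the initial/improved
# output pair per agent for evaluation.
crew.train(
    n_iterations=2,
    filename="trained_agents_data.pkl",
    inputs={"topic": "vector databases"},
)
```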
From 477cce321fe3fe4c8c40196e098666f3f27ce5b4 Mon Sep 17 00:00:00 2001
From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com>
Date: Wed, 29 Jan 2025 19:41:09 -0500
Subject: [PATCH 4/6] Fix llms (#2003)

* iwp

* add in api_base

---------

Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
---
 src/crewai/llm.py                 |  5 ++++-
 src/crewai/utilities/llm_utils.py | 17 +++++++++++++++--
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index 98b0bc855..ef8746fd5 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -133,6 +133,7 @@ class LLM:
         logprobs: Optional[int] = None,
         top_logprobs: Optional[int] = None,
         base_url: Optional[str] = None,
+        api_base: Optional[str] = None,
         api_version: Optional[str] = None,
         api_key: Optional[str] = None,
         callbacks: List[Any] = [],
@@ -152,6 +153,7 @@ class LLM:
         self.logprobs = logprobs
         self.top_logprobs = top_logprobs
         self.base_url = base_url
+        self.api_base = api_base
         self.api_version = api_version
         self.api_key = api_key
         self.callbacks = callbacks
@@ -232,7 +234,8 @@ class LLM:
             "seed": self.seed,
             "logprobs": self.logprobs,
             "top_logprobs": self.top_logprobs,
-            "api_base": self.base_url,
+            "api_base": self.api_base,
+            "base_url": self.base_url,
             "api_version": self.api_version,
             "api_key": self.api_key,
             "stream": False,
diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py
index 13230edf6..c774a71fb 100644
--- a/src/crewai/utilities/llm_utils.py
+++ b/src/crewai/utilities/llm_utils.py
@@ -53,6 +53,7 @@ def create_llm(
         timeout: Optional[float] = getattr(llm_value, "timeout", None)
         api_key: Optional[str] = getattr(llm_value, "api_key", None)
         base_url: Optional[str] = getattr(llm_value, "base_url", None)
+        api_base: Optional[str] = getattr(llm_value, "api_base", None)
 
         created_llm = LLM(
             model=model,
@@ -62,6 +63,7 @@ def create_llm(
             timeout=timeout,
             api_key=api_key,
             base_url=base_url,
+            api_base=api_base,
         )
         return created_llm
     except Exception as e:
@@ -101,8 +103,18 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
     callbacks: List[Any] = []
 
     # Optional base URL from env
-    api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get("OPENAI_BASE_URL")
-    if api_base:
+    base_url = (
+        os.environ.get("BASE_URL")
+        or os.environ.get("OPENAI_API_BASE")
+        or os.environ.get("OPENAI_BASE_URL")
+    )
+
+    api_base = os.environ.get("API_BASE") or os.environ.get("AZURE_API_BASE")
+
+    # Synchronize base_url and api_base if one is populated and the other is not
+    if base_url and not api_base:
+        api_base = base_url
+    elif api_base and not base_url:
         base_url = api_base
 
     # Initialize llm_params dictionary
@@ -115,6 +127,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
         "timeout": timeout,
         "api_key": api_key,
         "base_url": base_url,
+        "api_base": api_base,
         "api_version": api_version,
         "presence_penalty": presence_penalty,
         "frequency_penalty": frequency_penalty,
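With `base_url` and `api_base` now tracked separately and both forwarded to litellm, an OpenAI-compatible proxy can be targeted explicitly. A minimal sketch (the URL, model name, and key are placeholders):

```python
from crewai import LLM

llm = LLM(
    model="openai/my-hosted-model",       # placeholder model name
    base_url="http://localhost:8000/v1",  # OpenAI-compatible endpoint
    api_base="http://localhost:8000/v1",  # forwarded to litellm as api_base
    api_key="sk-placeholder",
)
# llm.call([{"role": "user", "content": "ping"}])
```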
From ddb7958da7d24336e152d9b6f34aa0a3bcc04221 Mon Sep 17 00:00:00 2001
From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com>
Date: Thu, 30 Jan 2025 18:16:10 -0500
Subject: [PATCH 5/6] Clean up to match enterprise (#2009)

* Clean up to match enterprise

* improve feedback prompting

---
 src/crewai/agents/crew_agent_executor.py | 12 ++++++++++--
 src/crewai/translations/en.json          |  3 ++-
 src/crewai/utilities/training_handler.py |  4 +---
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py
index b144872b1..ed89008fd 100644
--- a/src/crewai/agents/crew_agent_executor.py
+++ b/src/crewai/agents/crew_agent_executor.py
@@ -519,7 +519,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             color="yellow",
         )
         self._handle_crew_training_output(initial_answer, feedback)
-        self.messages.append(self._format_msg(f"Feedback: {feedback}"))
+        self.messages.append(
+            self._format_msg(
+                self._i18n.slice("feedback_instructions").format(feedback=feedback)
+            )
+        )
         improved_answer = self._invoke_loop()
         self._handle_crew_training_output(improved_answer)
         self.ask_for_human_input = False
@@ -566,7 +570,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 
     def _process_feedback_iteration(self, feedback: str) -> AgentFinish:
         """Process a single feedback iteration."""
-        self.messages.append(self._format_msg(f"Feedback: {feedback}"))
+        self.messages.append(
+            self._format_msg(
+                self._i18n.slice("feedback_instructions").format(feedback=feedback)
+            )
+        )
         return self._invoke_loop()
 
     def _log_feedback_error(self, retry_count: int, error: Exception) -> None:
diff --git a/src/crewai/translations/en.json b/src/crewai/translations/en.json
index 6385d5862..0c45321ea 100644
--- a/src/crewai/translations/en.json
+++ b/src/crewai/translations/en.json
@@ -24,7 +24,8 @@
     "manager_request": "Your best answer to your coworker asking you this, accounting for the context shared.",
     "formatted_task_instructions": "Ensure your final answer contains only the content in the following format: {output_format}\n\nEnsure the final output does not include any code block markers like ```json or ```python.",
     "human_feedback_classification": "Determine if the following feedback indicates that the user is satisfied or if further changes are needed. Respond with 'True' if further changes are needed, or 'False' if the user is satisfied. **Important** Do not include any additional commentary outside of your 'True' or 'False' response.\n\nFeedback: \"{feedback}\"",
-    "conversation_history_instruction": "You are a member of a crew collaborating to achieve a common goal. Your task is a specific action that contributes to this larger objective. For additional context, please review the conversation history between you and the user that led to the initiation of this crew. Use any relevant information or feedback from the conversation to inform your task execution and ensure your response aligns with both the immediate task and the crew's overall goals."
+    "conversation_history_instruction": "You are a member of a crew collaborating to achieve a common goal. Your task is a specific action that contributes to this larger objective. For additional context, please review the conversation history between you and the user that led to the initiation of this crew. Use any relevant information or feedback from the conversation to inform your task execution and ensure your response aligns with both the immediate task and the crew's overall goals.",
+    "feedback_instructions": "User feedback: {feedback}\nInstructions: Use this feedback to enhance the next output iteration.\nNote: Do not respond or add commentary."
   },
   "errors": {
     "force_final_answer_error": "You can't keep going, here is the best final answer you generated:\n\n {formatted_answer}",
diff --git a/src/crewai/utilities/training_handler.py b/src/crewai/utilities/training_handler.py
index b6b3c38b6..2d34f3261 100644
--- a/src/crewai/utilities/training_handler.py
+++ b/src/crewai/utilities/training_handler.py
@@ -35,6 +35,4 @@ class CrewTrainingHandler(PickleHandler):
     def clear(self) -> None:
         """Clear the training data by removing the file or resetting its contents."""
         if os.path.exists(self.file_path):
-            with open(self.file_path, "wb") as file:
-                # Overwrite with an empty dictionary
-                self.save({})
+            self.save({})
From 23b9e1032308d81aa16072932ea7884d44bc4da1 Mon Sep 17 00:00:00 2001
From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com>
Date: Fri, 31 Jan 2025 12:53:58 -0500
Subject: [PATCH 6/6] Brandon/provide llm additional params (#2018)

* Clean up to match enterprise

* add additional params to LLM calls

* make sure additional params are getting passed to llm

* update docs

* drop print

---
 docs/concepts/llms.mdx | 13 +++++++++++-
 src/crewai/llm.py      |  3 +++
 tests/llm_test.py      | 48 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 63 insertions(+), 1 deletion(-)

diff --git a/docs/concepts/llms.mdx b/docs/concepts/llms.mdx
index 261a1fdd8..0358308f4 100644
--- a/docs/concepts/llms.mdx
+++ b/docs/concepts/llms.mdx
@@ -465,11 +465,22 @@ Learn how to get the most out of your LLM configuration:
     # https://cloud.google.com/vertex-ai/generative-ai/docs/overview
     ```
 
+    ## GET CREDENTIALS
+    file_path = 'path/to/vertex_ai_service_account.json'
+
+    # Load the JSON file
+    with open(file_path, 'r') as file:
+        vertex_credentials = json.load(file)
+
+    # Convert to JSON string
+    vertex_credentials_json = json.dumps(vertex_credentials)
+
     Example usage:
     ```python Code
     llm = LLM(
         model="gemini/gemini-1.5-pro-latest",
-        temperature=0.7
+        temperature=0.7,
+        vertex_credentials=vertex_credentials_json
     )
     ```
diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index ef8746fd5..bbf8e35d9 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -137,6 +137,7 @@ class LLM:
         api_version: Optional[str] = None,
         api_key: Optional[str] = None,
         callbacks: List[Any] = [],
+        **kwargs,
     ):
         self.model = model
         self.timeout = timeout
@@ -158,6 +159,7 @@ class LLM:
         self.api_key = api_key
         self.callbacks = callbacks
         self.context_window_size = 0
+        self.additional_params = kwargs
 
         litellm.drop_params = True
 
@@ -240,6 +242,7 @@ class LLM:
             "api_key": self.api_key,
             "stream": False,
             "tools": tools,
+            **self.additional_params,
         }
 
         # Remove None values from params
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 6d1e6a188..8db8726d0 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -1,4 +1,5 @@
 from time import sleep
+from unittest.mock import MagicMock, patch
 
 import pytest
 
@@ -154,3 +155,50 @@ def test_llm_call_with_tool_and_message_list():
 
     assert isinstance(result, int)
     assert result == 25
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_passes_additional_params():
+    llm = LLM(
+        model="gpt-4o-mini",
+        vertex_credentials="test_credentials",
+        vertex_project="test_project",
+    )
+
+    messages = [{"role": "user", "content": "Hello, world!"}]
+
+    with patch("litellm.completion") as mocked_completion:
+        # Create mocks for response structure
+        mock_message = MagicMock()
+        mock_message.content = "Test response"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {
+            "prompt_tokens": 5,
+            "completion_tokens": 5,
+            "total_tokens": 10,
+        }
+
+        # Set up the mocked completion to return the mock response
+        mocked_completion.return_value = mock_response
+
+        result = llm.call(messages)
+
+        # Assert that litellm.completion was called once
+        mocked_completion.assert_called_once()
+
+        # Retrieve the actual arguments with which litellm.completion was called
+        _, kwargs = mocked_completion.call_args
+
+        # Check that the additional_params were passed to litellm.completion
+        assert kwargs["vertex_credentials"] == "test_credentials"
+        assert kwargs["vertex_project"] == "test_project"
+
+        # Also verify that other expected parameters are present
+        assert kwargs["model"] == "gpt-4o-mini"
+        assert kwargs["messages"] == messages
+
+        # Check the result from llm.call
+        assert result == "Test response"
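Because unknown keyword arguments are stored as `additional_params` and merged into every `litellm.completion` call, provider-specific settings now pass straight through. A minimal sketch mirroring the docs and test above (the Vertex values are placeholders):

```python
from crewai import LLM

llm = LLM(
    model="gemini/gemini-1.5-pro-latest",
    temperature=0.7,
    # Extra kwargs ride along to litellm.completion; litellm.drop_params
    # discards any the selected provider does not support.
    vertex_credentials='{"type": "service_account"}',  # placeholder JSON string
    vertex_project="my-gcp-project",                   # placeholder project id
)
# llm.call([{"role": "user", "content": "Hello"}])
```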