fix: Add validation fix output_file issue when have '/' (#585)

* fix: Add validation fix output_file issue when have /

* fix: run black to format code

* fix: run black to format code
This commit is contained in:
Eduardo Chiarotti
2024-05-09 08:11:00 -03:00
committed by GitHub
parent 809b4b227c
commit 7eb4fcdaf4
4 changed files with 31 additions and 19 deletions

View File

@@ -303,9 +303,9 @@ class Agent(BaseModel):
} }
if self._rpm_controller: if self._rpm_controller:
executor_args[ executor_args["request_within_rpm_limit"] = (
"request_within_rpm_limit" self._rpm_controller.check_or_wait
] = self._rpm_controller.check_or_wait )
prompt = Prompts( prompt = Prompts(
i18n=self.i18n, i18n=self.i18n,

View File

@@ -3,7 +3,9 @@ from crewai_tools import BaseTool
class MyCustomTool(BaseTool): class MyCustomTool(BaseTool):
name: str = "Name of my tool" name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." description: str = (
"Clear description for what this tool is useful for, you agent will need this information to use it."
)
def _run(self, argument: str) -> str: def _run(self, argument: str) -> str:
# Implementation goes here # Implementation goes here

View File

@@ -1,8 +1,8 @@
import os
import re import re
import threading import threading
import uuid import uuid
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Optional, Type
import os
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from pydantic import UUID4, BaseModel, Field, field_validator, model_validator from pydantic import UUID4, BaseModel, Field, field_validator, model_validator
@@ -109,6 +109,14 @@ class Task(BaseModel):
"may_not_set_field", "This field is not to be set by the user.", {} "may_not_set_field", "This field is not to be set by the user.", {}
) )
@field_validator("output_file")
@classmethod
def output_file_validattion(cls, value: str) -> str:
"""Validate the output file path by removing the / from the beginning of the path."""
if value.startswith("/"):
return value[1:]
return value
@model_validator(mode="after") @model_validator(mode="after")
def set_attributes_based_on_config(self) -> "Task": def set_attributes_based_on_config(self) -> "Task":
"""Set attributes based on the agent configuration.""" """Set attributes based on the agent configuration."""
@@ -247,16 +255,16 @@ class Task(BaseModel):
return exported_result.model_dump() return exported_result.model_dump()
return exported_result return exported_result
except Exception: except Exception:
# sometimes the response contains valid JSON in the middle of text # sometimes the response contains valid JSON in the middle of text
match = re.search(r"({.*})", result, re.DOTALL) match = re.search(r"({.*})", result, re.DOTALL)
if match: if match:
try: try:
exported_result = model.model_validate_json(match.group(0)) exported_result = model.model_validate_json(match.group(0))
if self.output_json: if self.output_json:
return exported_result.model_dump() return exported_result.model_dump()
return exported_result return exported_result
except Exception: except Exception:
pass pass
llm = self.agent.function_calling_llm or self.agent.llm llm = self.agent.function_calling_llm or self.agent.llm
@@ -294,7 +302,7 @@ class Task(BaseModel):
def _save_file(self, result: Any) -> None: def _save_file(self, result: Any) -> None:
directory = os.path.dirname(self.output_file) directory = os.path.dirname(self.output_file)
if not os.path.exists(directory): if directory and not os.path.exists(directory):
os.makedirs(directory) os.makedirs(directory)
with open(self.output_file, "w") as file: with open(self.output_file, "w") as file:

View File

@@ -256,9 +256,11 @@ class Telemetry:
"async_execution?": task.async_execution, "async_execution?": task.async_execution,
"output": task.expected_output, "output": task.expected_output,
"agent_role": task.agent.role if task.agent else "None", "agent_role": task.agent.role if task.agent else "None",
"context": [task.description for task in task.context] "context": (
if task.context [task.description for task in task.context]
else "None", if task.context
else "None"
),
"tools_names": [ "tools_names": [
tool.name.casefold() for tool in task.tools tool.name.casefold() for tool in task.tools
], ],