fix: Add validation fix output_file issue when have '/' (#585)

* fix: Add validation fix output_file issue when have /

* fix: run black to format code

* fix: run black to format code
This commit is contained in:
Eduardo Chiarotti
2024-05-09 08:11:00 -03:00
committed by GitHub
parent 402c5f477b
commit 8063e1d154
4 changed files with 31 additions and 19 deletions

View File

@@ -303,9 +303,9 @@ class Agent(BaseModel):
}
if self._rpm_controller:
executor_args[
"request_within_rpm_limit"
] = self._rpm_controller.check_or_wait
executor_args["request_within_rpm_limit"] = (
self._rpm_controller.check_or_wait
)
prompt = Prompts(
i18n=self.i18n,

View File

@@ -3,7 +3,9 @@ from crewai_tools import BaseTool
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
description: str = (
"Clear description for what this tool is useful for, you agent will need this information to use it."
)
def _run(self, argument: str) -> str:
# Implementation goes here

View File

@@ -1,8 +1,8 @@
import os
import re
import threading
import uuid
from typing import Any, Dict, List, Optional, Type
import os
from langchain_openai import ChatOpenAI
from pydantic import UUID4, BaseModel, Field, field_validator, model_validator
@@ -109,6 +109,14 @@ class Task(BaseModel):
"may_not_set_field", "This field is not to be set by the user.", {}
)
@field_validator("output_file")
@classmethod
def output_file_validattion(cls, value: str) -> str:
"""Validate the output file path by removing the / from the beginning of the path."""
if value.startswith("/"):
return value[1:]
return value
@model_validator(mode="after")
def set_attributes_based_on_config(self) -> "Task":
"""Set attributes based on the agent configuration."""
@@ -247,16 +255,16 @@ class Task(BaseModel):
return exported_result.model_dump()
return exported_result
except Exception:
# sometimes the response contains valid JSON in the middle of text
match = re.search(r"({.*})", result, re.DOTALL)
if match:
try:
exported_result = model.model_validate_json(match.group(0))
if self.output_json:
return exported_result.model_dump()
return exported_result
except Exception:
pass
# sometimes the response contains valid JSON in the middle of text
match = re.search(r"({.*})", result, re.DOTALL)
if match:
try:
exported_result = model.model_validate_json(match.group(0))
if self.output_json:
return exported_result.model_dump()
return exported_result
except Exception:
pass
llm = self.agent.function_calling_llm or self.agent.llm
@@ -294,7 +302,7 @@ class Task(BaseModel):
def _save_file(self, result: Any) -> None:
directory = os.path.dirname(self.output_file)
if not os.path.exists(directory):
if directory and not os.path.exists(directory):
os.makedirs(directory)
with open(self.output_file, "w") as file:

View File

@@ -256,9 +256,11 @@ class Telemetry:
"async_execution?": task.async_execution,
"output": task.expected_output,
"agent_role": task.agent.role if task.agent else "None",
"context": [task.description for task in task.context]
if task.context
else "None",
"context": (
[task.description for task in task.context]
if task.context
else "None"
),
"tools_names": [
tool.name.casefold() for tool in task.tools
],