mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 14:52:36 +00:00
fix: restore type[BaseModel] | None on CrewStructuredTool.args_schema
This commit is contained in:
@@ -61,7 +61,7 @@ class CrewStructuredTool(BaseModel):
|
|||||||
|
|
||||||
name: str = Field(default="")
|
name: str = Field(default="")
|
||||||
description: str = Field(default="")
|
description: str = Field(default="")
|
||||||
args_schema: Any = Field(default=None)
|
args_schema: type[BaseModel] | None = Field(default=None)
|
||||||
func: Any = Field(default=None, exclude=True)
|
func: Any = Field(default=None, exclude=True)
|
||||||
result_as_answer: bool = Field(default=False)
|
result_as_answer: bool = Field(default=False)
|
||||||
max_usage_count: int | None = Field(default=None)
|
max_usage_count: int | None = Field(default=None)
|
||||||
@@ -179,6 +179,8 @@ class CrewStructuredTool(BaseModel):
|
|||||||
|
|
||||||
def _validate_function_signature(self) -> None:
|
def _validate_function_signature(self) -> None:
|
||||||
"""Validate that the function signature matches the args schema."""
|
"""Validate that the function signature matches the args schema."""
|
||||||
|
if not self.args_schema:
|
||||||
|
return
|
||||||
sig = inspect.signature(self.func)
|
sig = inspect.signature(self.func)
|
||||||
schema_fields = self.args_schema.model_fields
|
schema_fields = self.args_schema.model_fields
|
||||||
|
|
||||||
@@ -218,6 +220,8 @@ class CrewStructuredTool(BaseModel):
|
|||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
|
raise ValueError(f"Failed to parse arguments as JSON: {e}") from e
|
||||||
|
|
||||||
|
if not self.args_schema:
|
||||||
|
return raw_args if isinstance(raw_args, dict) else {}
|
||||||
try:
|
try:
|
||||||
validated_args = self.args_schema.model_validate(raw_args)
|
validated_args = self.args_schema.model_validate(raw_args)
|
||||||
return dict(validated_args.model_dump())
|
return dict(validated_args.model_dump())
|
||||||
@@ -265,6 +269,8 @@ class CrewStructuredTool(BaseModel):
|
|||||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""Legacy method for compatibility."""
|
"""Legacy method for compatibility."""
|
||||||
# Convert args/kwargs to our expected format
|
# Convert args/kwargs to our expected format
|
||||||
|
if not self.args_schema:
|
||||||
|
return self.func(*args, **kwargs)
|
||||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
|
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
|
||||||
input_dict.update(kwargs)
|
input_dict.update(kwargs)
|
||||||
return self.invoke(input_dict)
|
return self.invoke(input_dict)
|
||||||
@@ -311,6 +317,8 @@ class CrewStructuredTool(BaseModel):
|
|||||||
@property
|
@property
|
||||||
def args(self) -> dict[str, Any]:
|
def args(self) -> dict[str, Any]:
|
||||||
"""Get the tool's input arguments schema."""
|
"""Get the tool's input arguments schema."""
|
||||||
|
if not self.args_schema:
|
||||||
|
return {}
|
||||||
schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"]
|
schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"]
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user