mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Merge branch 'main' into lorenze/agent-executor-flow-pattern
This commit is contained in:
@@ -187,6 +187,97 @@ You can also deploy your crews directly through the CrewAI AOP web interface by
|
|||||||
|
|
||||||
</Steps>
|
</Steps>
|
||||||
|
|
||||||
|
## Option 3: Redeploy Using API (CI/CD Integration)
|
||||||
|
|
||||||
|
For automated deployments in CI/CD pipelines, you can use the CrewAI API to trigger redeployments of existing crews. This is particularly useful for GitHub Actions, Jenkins, or other automation workflows.
|
||||||
|
|
||||||
|
<Steps>
|
||||||
|
<Step title="Get Your Personal Access Token">
|
||||||
|
|
||||||
|
Navigate to your CrewAI AOP account settings to generate an API token:
|
||||||
|
|
||||||
|
1. Go to [app.crewai.com](https://app.crewai.com)
|
||||||
|
2. Click on **Settings** → **Account** → **Personal Access Token**
|
||||||
|
3. Generate a new token and copy it securely
|
||||||
|
4. Store this token as a secret in your CI/CD system
|
||||||
|
|
||||||
|
</Step>
|
||||||
|
|
||||||
|
<Step title="Find Your Automation UUID">
|
||||||
|
|
||||||
|
Locate the unique identifier for your deployed crew:
|
||||||
|
|
||||||
|
1. Go to **Automations** in your CrewAI AOP dashboard
|
||||||
|
2. Select your existing automation/crew
|
||||||
|
3. Click on **Additional Details**
|
||||||
|
4. Copy the **UUID** - this identifies your specific crew deployment
|
||||||
|
|
||||||
|
</Step>
|
||||||
|
|
||||||
|
<Step title="Trigger Redeployment via API">
|
||||||
|
|
||||||
|
Use the Deploy API endpoint to trigger a redeployment:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -i -X POST \
|
||||||
|
-H "Authorization: Bearer YOUR_PERSONAL_ACCESS_TOKEN" \
|
||||||
|
https://app.crewai.com/crewai_plus/api/v1/crews/YOUR-AUTOMATION-UUID/deploy
|
||||||
|
|
||||||
|
# HTTP/2 200
|
||||||
|
# content-type: application/json
|
||||||
|
#
|
||||||
|
# {
|
||||||
|
# "uuid": "your-automation-uuid",
|
||||||
|
# "status": "Deploy Enqueued",
|
||||||
|
# "public_url": "https://your-crew-deployment.crewai.com",
|
||||||
|
# "token": "your-bearer-token"
|
||||||
|
# }
|
||||||
|
```
|
||||||
|
|
||||||
|
<Info>
|
||||||
|
If your automation was first created connected to Git, the API will automatically pull the latest changes from your repository before redeploying.
|
||||||
|
</Info>
|
||||||
|
|
||||||
|
|
||||||
|
</Step>
|
||||||
|
|
||||||
|
<Step title="GitHub Actions Integration Example">
|
||||||
|
|
||||||
|
Here's a GitHub Actions workflow with more complex deployment triggers:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: Deploy CrewAI Automation
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
types: [ labeled ]
|
||||||
|
release:
|
||||||
|
types: [ published ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
deploy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: |
|
||||||
|
(github.event_name == 'push' && github.ref == 'refs/heads/main') ||
|
||||||
|
(github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')) ||
|
||||||
|
(github.event_name == 'release')
|
||||||
|
steps:
|
||||||
|
- name: Trigger CrewAI Redeployment
|
||||||
|
run: |
|
||||||
|
curl -X POST \
|
||||||
|
-H "Authorization: Bearer ${{ secrets.CREWAI_PAT }}" \
|
||||||
|
https://app.crewai.com/crewai_plus/api/v1/crews/${{ secrets.CREWAI_AUTOMATION_UUID }}/deploy
|
||||||
|
```
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
Add `CREWAI_PAT` and `CREWAI_AUTOMATION_UUID` as repository secrets. For PR deployments, add a "deploy" label to trigger the workflow.
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
</Step>
|
||||||
|
</Steps>
|
||||||
|
|
||||||
## ⚠️ Environment Variable Security Requirements
|
## ⚠️ Environment Variable Security Requirements
|
||||||
|
|
||||||
<Warning>
|
<Warning>
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from crewai.events.types.knowledge_events import (
|
|||||||
KnowledgeSearchQueryFailedEvent,
|
KnowledgeSearchQueryFailedEvent,
|
||||||
)
|
)
|
||||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||||
from crewai.utilities.converter import generate_model_description
|
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -5,10 +5,9 @@ from __future__ import annotations
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Final, Literal
|
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||||
|
|
||||||
from crewai.utilities.converter import generate_model_description
|
|
||||||
|
|
||||||
|
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -42,7 +41,7 @@ class BaseConverterAdapter(ABC):
|
|||||||
"""
|
"""
|
||||||
self.agent_adapter = agent_adapter
|
self.agent_adapter = agent_adapter
|
||||||
self._output_format: Literal["json", "pydantic"] | None = None
|
self._output_format: Literal["json", "pydantic"] | None = None
|
||||||
self._schema: str | None = None
|
self._schema: dict[str, Any] | None = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def configure_structured_output(self, task: Task) -> None:
|
def configure_structured_output(self, task: Task) -> None:
|
||||||
@@ -129,7 +128,7 @@ class BaseConverterAdapter(ABC):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_format_from_task(
|
def _configure_format_from_task(
|
||||||
task: Task,
|
task: Task,
|
||||||
) -> tuple[Literal["json", "pydantic"] | None, str | None]:
|
) -> tuple[Literal["json", "pydantic"] | None, dict[str, Any] | None]:
|
||||||
"""Determine output format and schema from task requirements.
|
"""Determine output format and schema from task requirements.
|
||||||
|
|
||||||
This is a helper method that examines the task's output requirements
|
This is a helper method that examines the task's output requirements
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ This module contains the OpenAIConverterAdapter class that handles structured
|
|||||||
output conversion for OpenAI agents, supporting JSON and Pydantic model formats.
|
output conversion for OpenAI agents, supporting JSON and Pydantic model formats.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
|
from crewai.agents.agent_adapters.base_converter_adapter import BaseConverterAdapter
|
||||||
@@ -61,7 +62,7 @@ class OpenAIConverterAdapter(BaseConverterAdapter):
|
|||||||
output_schema: str = (
|
output_schema: str = (
|
||||||
get_i18n()
|
get_i18n()
|
||||||
.slice("formatted_task_instructions")
|
.slice("formatted_task_instructions")
|
||||||
.format(output_format=self._schema)
|
.format(output_format=json.dumps(self._schema, indent=2))
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"{base_prompt}\n\n{output_schema}"
|
return f"{base_prompt}\n\n{output_schema}"
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import base64
|
import base64
|
||||||
|
from json import JSONDecodeError
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -162,9 +163,19 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
|||||||
|
|
||||||
if login_response.status_code != 200:
|
if login_response.status_code != 200:
|
||||||
console.print(
|
console.print(
|
||||||
"Authentication failed. Verify if the currently active organization access to the tool repository, and run 'crewai login' again. ",
|
"Authentication failed. Verify if the currently active organization can access the tool repository, and run 'crewai login' again.",
|
||||||
style="bold red",
|
style="bold red",
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
console.print(
|
||||||
|
f"[{login_response.status_code} error - {login_response.json().get('message', 'Unknown error')}]",
|
||||||
|
style="bold red italic",
|
||||||
|
)
|
||||||
|
except JSONDecodeError:
|
||||||
|
console.print(
|
||||||
|
f"[{login_response.status_code} error - Unknown error - Invalid JSON response]",
|
||||||
|
style="bold red italic",
|
||||||
|
)
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
|
|
||||||
login_response_json = login_response.json()
|
login_response_json = login_response.json()
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ from pydantic import BaseModel
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||||
from crewai.utilities.converter import generate_model_description
|
|
||||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||||
LLMContextLengthExceededError,
|
LLMContextLengthExceededError,
|
||||||
)
|
)
|
||||||
|
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||||
from crewai.utilities.types import LLMMessage
|
from crewai.utilities.types import LLMMessage
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,10 +18,10 @@ from crewai.events.types.llm_events import LLMCallType
|
|||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||||
from crewai.utilities.converter import generate_model_description
|
|
||||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||||
LLMContextLengthExceededError,
|
LLMContextLengthExceededError,
|
||||||
)
|
)
|
||||||
|
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||||
from crewai.utilities.types import LLMMessage
|
from crewai.utilities.types import LLMMessage
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -494,8 +494,11 @@ class Task(BaseModel):
|
|||||||
future: Future[TaskOutput],
|
future: Future[TaskOutput],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute the task asynchronously with context handling."""
|
"""Execute the task asynchronously with context handling."""
|
||||||
result = self._execute_core(agent, context, tools)
|
try:
|
||||||
future.set_result(result)
|
result = self._execute_core(agent, context, tools)
|
||||||
|
future.set_result(result)
|
||||||
|
except Exception as e:
|
||||||
|
future.set_exception(e)
|
||||||
|
|
||||||
async def aexecute_sync(
|
async def aexecute_sync(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -3,15 +3,13 @@ from __future__ import annotations
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from inspect import signature
|
from inspect import Parameter, signature
|
||||||
|
import json
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Generic,
|
Generic,
|
||||||
ParamSpec,
|
ParamSpec,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
|
||||||
get_args,
|
|
||||||
get_origin,
|
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,6 +25,7 @@ from typing_extensions import TypeIs
|
|||||||
|
|
||||||
from crewai.tools.structured_tool import CrewStructuredTool
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
|
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||||
|
|
||||||
|
|
||||||
_printer = Printer()
|
_printer = Printer()
|
||||||
@@ -103,20 +102,40 @@ class BaseTool(BaseModel, ABC):
|
|||||||
if v != cls._ArgsSchemaPlaceholder:
|
if v != cls._ArgsSchemaPlaceholder:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
return cast(
|
run_sig = signature(cls._run)
|
||||||
type[PydanticBaseModel],
|
fields: dict[str, Any] = {}
|
||||||
type(
|
|
||||||
f"{cls.__name__}Schema",
|
for param_name, param in run_sig.parameters.items():
|
||||||
(PydanticBaseModel,),
|
if param_name in ("self", "return"):
|
||||||
{
|
continue
|
||||||
"__annotations__": {
|
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
|
||||||
k: v
|
continue
|
||||||
for k, v in cls._run.__annotations__.items()
|
|
||||||
if k != "return"
|
annotation = param.annotation if param.annotation != param.empty else Any
|
||||||
},
|
|
||||||
},
|
if param.default is param.empty:
|
||||||
),
|
fields[param_name] = (annotation, ...)
|
||||||
)
|
else:
|
||||||
|
fields[param_name] = (annotation, param.default)
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
arun_sig = signature(cls._arun)
|
||||||
|
for param_name, param in arun_sig.parameters.items():
|
||||||
|
if param_name in ("self", "return"):
|
||||||
|
continue
|
||||||
|
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
|
||||||
|
continue
|
||||||
|
|
||||||
|
annotation = (
|
||||||
|
param.annotation if param.annotation != param.empty else Any
|
||||||
|
)
|
||||||
|
|
||||||
|
if param.default is param.empty:
|
||||||
|
fields[param_name] = (annotation, ...)
|
||||||
|
else:
|
||||||
|
fields[param_name] = (annotation, param.default)
|
||||||
|
|
||||||
|
return create_model(f"{cls.__name__}Schema", **fields)
|
||||||
|
|
||||||
@field_validator("max_usage_count", mode="before")
|
@field_validator("max_usage_count", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -226,24 +245,23 @@ class BaseTool(BaseModel, ABC):
|
|||||||
args_schema = getattr(tool, "args_schema", None)
|
args_schema = getattr(tool, "args_schema", None)
|
||||||
|
|
||||||
if args_schema is None:
|
if args_schema is None:
|
||||||
# Infer args_schema from the function signature if not provided
|
|
||||||
func_signature = signature(tool.func)
|
func_signature = signature(tool.func)
|
||||||
annotations = func_signature.parameters
|
fields: dict[str, Any] = {}
|
||||||
args_fields: dict[str, Any] = {}
|
for name, param in func_signature.parameters.items():
|
||||||
for name, param in annotations.items():
|
if name == "self":
|
||||||
if name != "self":
|
continue
|
||||||
param_annotation = (
|
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
|
||||||
param.annotation if param.annotation != param.empty else Any
|
continue
|
||||||
)
|
param_annotation = (
|
||||||
field_info = Field(
|
param.annotation if param.annotation != param.empty else Any
|
||||||
default=...,
|
)
|
||||||
description="",
|
if param.default is param.empty:
|
||||||
)
|
fields[name] = (param_annotation, ...)
|
||||||
args_fields[name] = (param_annotation, field_info)
|
else:
|
||||||
if args_fields:
|
fields[name] = (param_annotation, param.default)
|
||||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
if fields:
|
||||||
|
args_schema = create_model(f"{tool.name}Input", **fields)
|
||||||
else:
|
else:
|
||||||
# Create a default schema with no fields if no parameters are found
|
|
||||||
args_schema = create_model(
|
args_schema = create_model(
|
||||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
f"{tool.name}Input", __base__=PydanticBaseModel
|
||||||
)
|
)
|
||||||
@@ -257,53 +275,37 @@ class BaseTool(BaseModel, ABC):
|
|||||||
|
|
||||||
def _set_args_schema(self) -> None:
|
def _set_args_schema(self) -> None:
|
||||||
if self.args_schema is None:
|
if self.args_schema is None:
|
||||||
class_name = f"{self.__class__.__name__}Schema"
|
run_sig = signature(self._run)
|
||||||
self.args_schema = cast(
|
fields: dict[str, Any] = {}
|
||||||
type[PydanticBaseModel],
|
|
||||||
type(
|
for param_name, param in run_sig.parameters.items():
|
||||||
class_name,
|
if param_name in ("self", "return"):
|
||||||
(PydanticBaseModel,),
|
continue
|
||||||
{
|
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
|
||||||
"__annotations__": {
|
continue
|
||||||
k: v
|
|
||||||
for k, v in self._run.__annotations__.items()
|
annotation = (
|
||||||
if k != "return"
|
param.annotation if param.annotation != param.empty else Any
|
||||||
},
|
)
|
||||||
},
|
|
||||||
),
|
if param.default is param.empty:
|
||||||
|
fields[param_name] = (annotation, ...)
|
||||||
|
else:
|
||||||
|
fields[param_name] = (annotation, param.default)
|
||||||
|
|
||||||
|
self.args_schema = create_model(
|
||||||
|
f"{self.__class__.__name__}Schema", **fields
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_description(self) -> None:
|
def _generate_description(self) -> None:
|
||||||
args_schema = {
|
"""Generate the tool description with a JSON schema for arguments."""
|
||||||
name: {
|
schema = generate_model_description(self.args_schema)
|
||||||
"description": field.description,
|
args_json = json.dumps(schema["json_schema"]["schema"], indent=2)
|
||||||
"type": BaseTool._get_arg_annotations(field.annotation),
|
self.description = (
|
||||||
}
|
f"Tool Name: {self.name}\n"
|
||||||
for name, field in self.args_schema.model_fields.items()
|
f"Tool Arguments: {args_json}\n"
|
||||||
}
|
f"Tool Description: {self.description}"
|
||||||
|
)
|
||||||
self.description = f"Tool Name: {self.name}\nTool Arguments: {args_schema}\nTool Description: {self.description}"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_arg_annotations(annotation: type[Any] | None) -> str:
|
|
||||||
if annotation is None:
|
|
||||||
return "None"
|
|
||||||
|
|
||||||
origin = get_origin(annotation)
|
|
||||||
args = get_args(annotation)
|
|
||||||
|
|
||||||
if origin is None:
|
|
||||||
return (
|
|
||||||
annotation.__name__
|
|
||||||
if hasattr(annotation, "__name__")
|
|
||||||
else str(annotation)
|
|
||||||
)
|
|
||||||
|
|
||||||
if args:
|
|
||||||
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
|
||||||
return str(f"{origin.__name__}[{args_str}]")
|
|
||||||
|
|
||||||
return str(origin.__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseTool, Generic[P, R]):
|
class Tool(BaseTool, Generic[P, R]):
|
||||||
@@ -406,24 +408,23 @@ class Tool(BaseTool, Generic[P, R]):
|
|||||||
args_schema = getattr(tool, "args_schema", None)
|
args_schema = getattr(tool, "args_schema", None)
|
||||||
|
|
||||||
if args_schema is None:
|
if args_schema is None:
|
||||||
# Infer args_schema from the function signature if not provided
|
|
||||||
func_signature = signature(tool.func)
|
func_signature = signature(tool.func)
|
||||||
annotations = func_signature.parameters
|
fields: dict[str, Any] = {}
|
||||||
args_fields: dict[str, Any] = {}
|
for name, param in func_signature.parameters.items():
|
||||||
for name, param in annotations.items():
|
if name == "self":
|
||||||
if name != "self":
|
continue
|
||||||
param_annotation = (
|
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
|
||||||
param.annotation if param.annotation != param.empty else Any
|
continue
|
||||||
)
|
param_annotation = (
|
||||||
field_info = Field(
|
param.annotation if param.annotation != param.empty else Any
|
||||||
default=...,
|
)
|
||||||
description="",
|
if param.default is param.empty:
|
||||||
)
|
fields[name] = (param_annotation, ...)
|
||||||
args_fields[name] = (param_annotation, field_info)
|
else:
|
||||||
if args_fields:
|
fields[name] = (param_annotation, param.default)
|
||||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
if fields:
|
||||||
|
args_schema = create_model(f"{tool.name}Input", **fields)
|
||||||
else:
|
else:
|
||||||
# Create a default schema with no fields if no parameters are found
|
|
||||||
args_schema = create_model(
|
args_schema = create_model(
|
||||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
f"{tool.name}Input", __base__=PydanticBaseModel
|
||||||
)
|
)
|
||||||
@@ -502,32 +503,38 @@ def tool(
|
|||||||
def _make_tool(f: Callable[P2, R2]) -> Tool[P2, R2]:
|
def _make_tool(f: Callable[P2, R2]) -> Tool[P2, R2]:
|
||||||
if f.__doc__ is None:
|
if f.__doc__ is None:
|
||||||
raise ValueError("Function must have a docstring")
|
raise ValueError("Function must have a docstring")
|
||||||
|
if f.__annotations__ is None:
|
||||||
func_annotations = getattr(f, "__annotations__", None)
|
|
||||||
if func_annotations is None:
|
|
||||||
raise ValueError("Function must have type annotations")
|
raise ValueError("Function must have type annotations")
|
||||||
|
|
||||||
|
func_sig = signature(f)
|
||||||
|
fields: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for param_name, param in func_sig.parameters.items():
|
||||||
|
if param_name == "return":
|
||||||
|
continue
|
||||||
|
if param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
|
||||||
|
continue
|
||||||
|
|
||||||
|
annotation = (
|
||||||
|
param.annotation if param.annotation != param.empty else Any
|
||||||
|
)
|
||||||
|
|
||||||
|
if param.default is param.empty:
|
||||||
|
fields[param_name] = (annotation, ...)
|
||||||
|
else:
|
||||||
|
fields[param_name] = (annotation, param.default)
|
||||||
|
|
||||||
class_name = "".join(tool_name.split()).title()
|
class_name = "".join(tool_name.split()).title()
|
||||||
tool_args_schema = cast(
|
args_schema = create_model(class_name, **fields)
|
||||||
type[PydanticBaseModel],
|
|
||||||
type(
|
|
||||||
class_name,
|
|
||||||
(PydanticBaseModel,),
|
|
||||||
{
|
|
||||||
"__annotations__": {
|
|
||||||
k: v for k, v in func_annotations.items() if k != "return"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return Tool(
|
return Tool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
description=f.__doc__,
|
description=f.__doc__,
|
||||||
func=f,
|
func=f,
|
||||||
args_schema=tool_args_schema,
|
args_schema=args_schema,
|
||||||
result_as_answer=result_as_answer,
|
result_as_answer=result_as_answer,
|
||||||
max_usage_count=max_usage_count,
|
max_usage_count=max_usage_count,
|
||||||
|
current_usage_count=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _make_tool
|
return _make_tool
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
from copy import deepcopy
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, Any, Final, TypedDict
|
from typing import TYPE_CHECKING, Any, Final, TypedDict
|
||||||
@@ -13,6 +11,7 @@ from crewai.agents.agent_builder.utilities.base_output_converter import OutputCo
|
|||||||
from crewai.utilities.i18n import get_i18n
|
from crewai.utilities.i18n import get_i18n
|
||||||
from crewai.utilities.internal_instructor import InternalInstructor
|
from crewai.utilities.internal_instructor import InternalInstructor
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
|
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -421,221 +420,3 @@ def create_converter(
|
|||||||
raise Exception("No output converter found or set.")
|
raise Exception("No output converter found or set.")
|
||||||
|
|
||||||
return converter # type: ignore[no-any-return]
|
return converter # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Recursively resolve all local $refs in the given JSON Schema using $defs as the source.
|
|
||||||
|
|
||||||
This is needed because Pydantic generates $ref-based schemas that
|
|
||||||
some consumers (e.g. LLMs, tool frameworks) don't handle well.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
schema: JSON Schema dict that may contain "$refs" and "$defs".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A new schema dictionary with all local $refs replaced by their definitions.
|
|
||||||
"""
|
|
||||||
defs = schema.get("$defs", {})
|
|
||||||
schema_copy = deepcopy(schema)
|
|
||||||
|
|
||||||
def _resolve(node: Any) -> Any:
|
|
||||||
if isinstance(node, dict):
|
|
||||||
ref = node.get("$ref")
|
|
||||||
if isinstance(ref, str) and ref.startswith("#/$defs/"):
|
|
||||||
def_name = ref.replace("#/$defs/", "")
|
|
||||||
if def_name in defs:
|
|
||||||
return _resolve(deepcopy(defs[def_name]))
|
|
||||||
raise KeyError(f"Definition '{def_name}' not found in $defs.")
|
|
||||||
return {k: _resolve(v) for k, v in node.items()}
|
|
||||||
|
|
||||||
if isinstance(node, list):
|
|
||||||
return [_resolve(i) for i in node]
|
|
||||||
|
|
||||||
return node
|
|
||||||
|
|
||||||
return _resolve(schema_copy) # type: ignore[no-any-return]
|
|
||||||
|
|
||||||
|
|
||||||
def add_key_in_dict_recursively(
|
|
||||||
d: dict[str, Any], key: str, value: Any, criteria: Callable[[dict[str, Any]], bool]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Recursively adds a key/value pair to all nested dicts matching `criteria`."""
|
|
||||||
if isinstance(d, dict):
|
|
||||||
if criteria(d) and key not in d:
|
|
||||||
d[key] = value
|
|
||||||
for v in d.values():
|
|
||||||
add_key_in_dict_recursively(v, key, value, criteria)
|
|
||||||
elif isinstance(d, list):
|
|
||||||
for i in d:
|
|
||||||
add_key_in_dict_recursively(i, key, value, criteria)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def fix_discriminator_mappings(schema: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Replace '#/$defs/...' references in discriminator.mapping with just the model name."""
|
|
||||||
output = schema.get("properties", {}).get("output")
|
|
||||||
if not output:
|
|
||||||
return schema
|
|
||||||
|
|
||||||
disc = output.get("discriminator")
|
|
||||||
if not disc or "mapping" not in disc:
|
|
||||||
return schema
|
|
||||||
|
|
||||||
disc["mapping"] = {k: v.split("/")[-1] for k, v in disc["mapping"].items()}
|
|
||||||
return schema
|
|
||||||
|
|
||||||
|
|
||||||
def add_const_to_oneof_variants(schema: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Add const fields to oneOf variants for discriminated unions.
|
|
||||||
|
|
||||||
The json_schema_to_pydantic library requires each oneOf variant to have
|
|
||||||
a const field for the discriminator property. This function adds those
|
|
||||||
const fields based on the discriminator mapping.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
schema: JSON Schema dict that may contain discriminated unions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Modified schema with const fields added to oneOf variants
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _process_oneof(node: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Process a single node that might contain a oneOf with discriminator."""
|
|
||||||
if not isinstance(node, dict):
|
|
||||||
return node
|
|
||||||
|
|
||||||
if "oneOf" in node and "discriminator" in node:
|
|
||||||
discriminator = node["discriminator"]
|
|
||||||
property_name = discriminator.get("propertyName")
|
|
||||||
mapping = discriminator.get("mapping", {})
|
|
||||||
|
|
||||||
if property_name and mapping:
|
|
||||||
one_of_variants = node.get("oneOf", [])
|
|
||||||
|
|
||||||
for variant in one_of_variants:
|
|
||||||
if isinstance(variant, dict) and "properties" in variant:
|
|
||||||
variant_title = variant.get("title", "")
|
|
||||||
|
|
||||||
matched_disc_value = None
|
|
||||||
for disc_value, schema_name in mapping.items():
|
|
||||||
if variant_title == schema_name or variant_title.endswith(
|
|
||||||
schema_name
|
|
||||||
):
|
|
||||||
matched_disc_value = disc_value
|
|
||||||
break
|
|
||||||
|
|
||||||
if matched_disc_value is not None:
|
|
||||||
props = variant["properties"]
|
|
||||||
if property_name in props:
|
|
||||||
props[property_name]["const"] = matched_disc_value
|
|
||||||
|
|
||||||
for key, value in node.items():
|
|
||||||
if isinstance(value, dict):
|
|
||||||
node[key] = _process_oneof(value)
|
|
||||||
elif isinstance(value, list):
|
|
||||||
node[key] = [
|
|
||||||
_process_oneof(item) if isinstance(item, dict) else item
|
|
||||||
for item in value
|
|
||||||
]
|
|
||||||
|
|
||||||
return node
|
|
||||||
|
|
||||||
return _process_oneof(deepcopy(schema))
|
|
||||||
|
|
||||||
|
|
||||||
def convert_oneof_to_anyof(schema: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Convert oneOf to anyOf for OpenAI compatibility.
|
|
||||||
|
|
||||||
OpenAI's Structured Outputs support anyOf better than oneOf.
|
|
||||||
This recursively converts all oneOf occurrences to anyOf.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
schema: JSON schema dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Modified schema with anyOf instead of oneOf.
|
|
||||||
"""
|
|
||||||
if isinstance(schema, dict):
|
|
||||||
if "oneOf" in schema:
|
|
||||||
schema["anyOf"] = schema.pop("oneOf")
|
|
||||||
|
|
||||||
for value in schema.values():
|
|
||||||
if isinstance(value, dict):
|
|
||||||
convert_oneof_to_anyof(value)
|
|
||||||
elif isinstance(value, list):
|
|
||||||
for item in value:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
convert_oneof_to_anyof(item)
|
|
||||||
|
|
||||||
return schema
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_all_properties_required(schema: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Ensure all properties are in the required array for OpenAI strict mode.
|
|
||||||
|
|
||||||
OpenAI's strict structured outputs require all properties to be listed
|
|
||||||
in the required array. This recursively updates all objects to include
|
|
||||||
all their properties in required.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
schema: JSON schema dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Modified schema with all properties marked as required.
|
|
||||||
"""
|
|
||||||
if isinstance(schema, dict):
|
|
||||||
if schema.get("type") == "object" and "properties" in schema:
|
|
||||||
properties = schema["properties"]
|
|
||||||
if properties:
|
|
||||||
schema["required"] = list(properties.keys())
|
|
||||||
|
|
||||||
for value in schema.values():
|
|
||||||
if isinstance(value, dict):
|
|
||||||
ensure_all_properties_required(value)
|
|
||||||
elif isinstance(value, list):
|
|
||||||
for item in value:
|
|
||||||
if isinstance(item, dict):
|
|
||||||
ensure_all_properties_required(item)
|
|
||||||
|
|
||||||
return schema
|
|
||||||
|
|
||||||
|
|
||||||
def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
|
|
||||||
"""Generate JSON schema description of a Pydantic model.
|
|
||||||
|
|
||||||
This function takes a Pydantic model class and returns its JSON schema,
|
|
||||||
which includes full type information, discriminators, and all metadata.
|
|
||||||
The schema is dereferenced to inline all $ref references for better LLM understanding.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: A Pydantic model class.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A JSON schema dictionary representation of the model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
json_schema = model.model_json_schema(ref_template="#/$defs/{model}")
|
|
||||||
|
|
||||||
json_schema = add_key_in_dict_recursively(
|
|
||||||
json_schema,
|
|
||||||
key="additionalProperties",
|
|
||||||
value=False,
|
|
||||||
criteria=lambda d: d.get("type") == "object"
|
|
||||||
and "additionalProperties" not in d,
|
|
||||||
)
|
|
||||||
|
|
||||||
json_schema = resolve_refs(json_schema)
|
|
||||||
|
|
||||||
json_schema.pop("$defs", None)
|
|
||||||
json_schema = fix_discriminator_mappings(json_schema)
|
|
||||||
json_schema = convert_oneof_to_anyof(json_schema)
|
|
||||||
json_schema = ensure_all_properties_required(json_schema)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "json_schema",
|
|
||||||
"json_schema": {
|
|
||||||
"name": model.__name__,
|
|
||||||
"strict": True,
|
|
||||||
"schema": json_schema,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, cast
|
import json
|
||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.task_events import TaskEvaluationEvent
|
from crewai.events.types.task_events import TaskEvaluationEvent
|
||||||
from crewai.llm import LLM
|
|
||||||
from crewai.utilities.converter import Converter
|
from crewai.utilities.converter import Converter
|
||||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
from crewai.utilities.i18n import get_i18n
|
||||||
|
from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||||
from crewai.utilities.training_converter import TrainingConverter
|
from crewai.utilities.training_converter import TrainingConverter
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +63,7 @@ class TaskEvaluator:
|
|||||||
Args:
|
Args:
|
||||||
original_agent: The agent to evaluate.
|
original_agent: The agent to evaluate.
|
||||||
"""
|
"""
|
||||||
self.llm = cast(LLM, original_agent.llm)
|
self.llm = original_agent.llm
|
||||||
self.original_agent = original_agent
|
self.original_agent = original_agent
|
||||||
|
|
||||||
def evaluate(self, task: Task, output: str) -> TaskEvaluation:
|
def evaluate(self, task: Task, output: str) -> TaskEvaluation:
|
||||||
@@ -79,7 +80,8 @@ class TaskEvaluator:
|
|||||||
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
||||||
"""
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self, TaskEvaluationEvent(evaluation_type="task_evaluation", task=task)
|
self,
|
||||||
|
TaskEvaluationEvent(evaluation_type="task_evaluation", task=task), # type: ignore[no-untyped-call]
|
||||||
)
|
)
|
||||||
evaluation_query = (
|
evaluation_query = (
|
||||||
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
f"Assess the quality of the task completed based on the description, expected output, and actual results.\n\n"
|
||||||
@@ -94,9 +96,14 @@ class TaskEvaluator:
|
|||||||
|
|
||||||
instructions = "Convert all responses into valid JSON output."
|
instructions = "Convert all responses into valid JSON output."
|
||||||
|
|
||||||
if not self.llm.supports_function_calling():
|
if not self.llm.supports_function_calling(): # type: ignore[union-attr]
|
||||||
model_schema = PydanticSchemaParser(model=TaskEvaluation).get_schema()
|
schema_dict = generate_model_description(TaskEvaluation)
|
||||||
instructions = f"{instructions}\n\nReturn only valid JSON with the following schema:\n```json\n{model_schema}\n```"
|
output_schema: str = (
|
||||||
|
get_i18n()
|
||||||
|
.slice("formatted_task_instructions")
|
||||||
|
.format(output_format=json.dumps(schema_dict, indent=2))
|
||||||
|
)
|
||||||
|
instructions = f"{instructions}\n\n{output_schema}"
|
||||||
|
|
||||||
converter = Converter(
|
converter = Converter(
|
||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
@@ -108,7 +115,7 @@ class TaskEvaluator:
|
|||||||
return cast(TaskEvaluation, converter.to_pydantic())
|
return cast(TaskEvaluation, converter.to_pydantic())
|
||||||
|
|
||||||
def evaluate_training_data(
|
def evaluate_training_data(
|
||||||
self, training_data: dict, agent_id: str
|
self, training_data: dict[str, Any], agent_id: str
|
||||||
) -> TrainingTaskEvaluation:
|
) -> TrainingTaskEvaluation:
|
||||||
"""
|
"""
|
||||||
Evaluate the training data based on the llm output, human feedback, and improved output.
|
Evaluate the training data based on the llm output, human feedback, and improved output.
|
||||||
@@ -121,7 +128,8 @@ class TaskEvaluator:
|
|||||||
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
- Investigate the Converter.to_pydantic signature, returns BaseModel strictly?
|
||||||
"""
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self, TaskEvaluationEvent(evaluation_type="training_data_evaluation")
|
self,
|
||||||
|
TaskEvaluationEvent(evaluation_type="training_data_evaluation"), # type: ignore[no-untyped-call]
|
||||||
)
|
)
|
||||||
|
|
||||||
output_training_data = training_data[agent_id]
|
output_training_data = training_data[agent_id]
|
||||||
@@ -164,11 +172,14 @@ class TaskEvaluator:
|
|||||||
)
|
)
|
||||||
instructions = "I'm gonna convert this raw text into valid JSON."
|
instructions = "I'm gonna convert this raw text into valid JSON."
|
||||||
|
|
||||||
if not self.llm.supports_function_calling():
|
if not self.llm.supports_function_calling(): # type: ignore[union-attr]
|
||||||
model_schema = PydanticSchemaParser(
|
schema_dict = generate_model_description(TrainingTaskEvaluation)
|
||||||
model=TrainingTaskEvaluation
|
output_schema: str = (
|
||||||
).get_schema()
|
get_i18n()
|
||||||
instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"
|
.slice("formatted_task_instructions")
|
||||||
|
.format(output_format=json.dumps(schema_dict, indent=2))
|
||||||
|
)
|
||||||
|
instructions = f"{instructions}\n\n{output_schema}"
|
||||||
|
|
||||||
converter = TrainingConverter(
|
converter = TrainingConverter(
|
||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
|
|||||||
@@ -1,103 +0,0 @@
|
|||||||
from typing import Any, Union, get_args, get_origin
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class PydanticSchemaParser(BaseModel):
|
|
||||||
model: type[BaseModel] = Field(..., description="The Pydantic model to parse.")
|
|
||||||
|
|
||||||
def get_schema(self) -> str:
|
|
||||||
"""Public method to get the schema of a Pydantic model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
String representation of the model schema.
|
|
||||||
"""
|
|
||||||
return "{\n" + self._get_model_schema(self.model) + "\n}"
|
|
||||||
|
|
||||||
def _get_model_schema(self, model: type[BaseModel], depth: int = 0) -> str:
|
|
||||||
"""Recursively get the schema of a Pydantic model, handling nested models and lists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The Pydantic model to process.
|
|
||||||
depth: The current depth of recursion for indentation purposes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the model schema.
|
|
||||||
"""
|
|
||||||
indent: str = " " * 4 * depth
|
|
||||||
lines: list[str] = [
|
|
||||||
f"{indent} {field_name}: {self._get_field_type_for_annotation(field.annotation, depth + 1)}"
|
|
||||||
for field_name, field in model.model_fields.items()
|
|
||||||
]
|
|
||||||
return ",\n".join(lines)
|
|
||||||
|
|
||||||
def _format_list_type(self, list_item_type: Any, depth: int) -> str:
|
|
||||||
"""Format a List type, handling nested models if necessary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
list_item_type: The type of items in the list.
|
|
||||||
depth: The current depth of recursion for indentation purposes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the List type.
|
|
||||||
"""
|
|
||||||
if isinstance(list_item_type, type) and issubclass(list_item_type, BaseModel):
|
|
||||||
nested_schema = self._get_model_schema(list_item_type, depth + 1)
|
|
||||||
nested_indent = " " * 4 * depth
|
|
||||||
return f"List[\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}\n{nested_indent}]"
|
|
||||||
return f"List[{list_item_type.__name__}]"
|
|
||||||
|
|
||||||
def _format_union_type(self, field_type: Any, depth: int) -> str:
|
|
||||||
"""Format a Union type, handling Optional and nested types.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
field_type: The Union type to format.
|
|
||||||
depth: The current depth of recursion for indentation purposes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the Union type.
|
|
||||||
"""
|
|
||||||
args = get_args(field_type)
|
|
||||||
if type(None) in args:
|
|
||||||
# It's an Optional type
|
|
||||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
|
||||||
if len(non_none_args) == 1:
|
|
||||||
inner_type = self._get_field_type_for_annotation(
|
|
||||||
non_none_args[0], depth
|
|
||||||
)
|
|
||||||
return f"Optional[{inner_type}]"
|
|
||||||
# Union with None and multiple other types
|
|
||||||
inner_types = ", ".join(
|
|
||||||
self._get_field_type_for_annotation(arg, depth) for arg in non_none_args
|
|
||||||
)
|
|
||||||
return f"Optional[Union[{inner_types}]]"
|
|
||||||
# General Union type
|
|
||||||
inner_types = ", ".join(
|
|
||||||
self._get_field_type_for_annotation(arg, depth) for arg in args
|
|
||||||
)
|
|
||||||
return f"Union[{inner_types}]"
|
|
||||||
|
|
||||||
def _get_field_type_for_annotation(self, annotation: Any, depth: int) -> str:
|
|
||||||
"""Recursively get the string representation of a field's type annotation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
annotation: The type annotation to process.
|
|
||||||
depth: The current depth of recursion for indentation purposes.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A string representation of the type annotation.
|
|
||||||
"""
|
|
||||||
origin: Any = get_origin(annotation)
|
|
||||||
if origin is list:
|
|
||||||
list_item_type = get_args(annotation)[0]
|
|
||||||
return self._format_list_type(list_item_type, depth)
|
|
||||||
if origin is dict:
|
|
||||||
key_type, value_type = get_args(annotation)
|
|
||||||
return f"Dict[{key_type.__name__}, {value_type.__name__}]"
|
|
||||||
if origin is Union:
|
|
||||||
return self._format_union_type(annotation, depth)
|
|
||||||
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
|
||||||
nested_schema = self._get_model_schema(annotation, depth)
|
|
||||||
nested_indent = " " * 4 * depth
|
|
||||||
return f"{annotation.__name__}\n{nested_indent}{{\n{nested_schema}\n{nested_indent}}}"
|
|
||||||
return annotation.__name__
|
|
||||||
245
lib/crewai/src/crewai/utilities/pydantic_schema_utils.py
Normal file
245
lib/crewai/src/crewai/utilities/pydantic_schema_utils.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""Utilities for generating JSON schemas from Pydantic models.
|
||||||
|
|
||||||
|
This module provides functions for converting Pydantic models to JSON schemas
|
||||||
|
suitable for use with LLMs and tool definitions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Recursively resolve all local $refs in the given JSON Schema using $defs as the source.
|
||||||
|
|
||||||
|
This is needed because Pydantic generates $ref-based schemas that
|
||||||
|
some consumers (e.g. LLMs, tool frameworks) don't handle well.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: JSON Schema dict that may contain "$refs" and "$defs".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new schema dictionary with all local $refs replaced by their definitions.
|
||||||
|
"""
|
||||||
|
defs = schema.get("$defs", {})
|
||||||
|
schema_copy = deepcopy(schema)
|
||||||
|
|
||||||
|
def _resolve(node: Any) -> Any:
|
||||||
|
if isinstance(node, dict):
|
||||||
|
ref = node.get("$ref")
|
||||||
|
if isinstance(ref, str) and ref.startswith("#/$defs/"):
|
||||||
|
def_name = ref.replace("#/$defs/", "")
|
||||||
|
if def_name in defs:
|
||||||
|
return _resolve(deepcopy(defs[def_name]))
|
||||||
|
raise KeyError(f"Definition '{def_name}' not found in $defs.")
|
||||||
|
return {k: _resolve(v) for k, v in node.items()}
|
||||||
|
|
||||||
|
if isinstance(node, list):
|
||||||
|
return [_resolve(i) for i in node]
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
return _resolve(schema_copy) # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
|
||||||
|
def add_key_in_dict_recursively(
|
||||||
|
d: dict[str, Any], key: str, value: Any, criteria: Callable[[dict[str, Any]], bool]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Recursively adds a key/value pair to all nested dicts matching `criteria`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d: The dictionary to modify.
|
||||||
|
key: The key to add.
|
||||||
|
value: The value to add.
|
||||||
|
criteria: A function that returns True for dicts that should receive the key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The modified dictionary.
|
||||||
|
"""
|
||||||
|
if isinstance(d, dict):
|
||||||
|
if criteria(d) and key not in d:
|
||||||
|
d[key] = value
|
||||||
|
for v in d.values():
|
||||||
|
add_key_in_dict_recursively(v, key, value, criteria)
|
||||||
|
elif isinstance(d, list):
|
||||||
|
for i in d:
|
||||||
|
add_key_in_dict_recursively(i, key, value, criteria)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def fix_discriminator_mappings(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Replace '#/$defs/...' references in discriminator.mapping with just the model name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: JSON schema dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified schema with fixed discriminator mappings.
|
||||||
|
"""
|
||||||
|
output = schema.get("properties", {}).get("output")
|
||||||
|
if not output:
|
||||||
|
return schema
|
||||||
|
|
||||||
|
disc = output.get("discriminator")
|
||||||
|
if not disc or "mapping" not in disc:
|
||||||
|
return schema
|
||||||
|
|
||||||
|
disc["mapping"] = {k: v.split("/")[-1] for k, v in disc["mapping"].items()}
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def add_const_to_oneof_variants(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Add const fields to oneOf variants for discriminated unions.
|
||||||
|
|
||||||
|
The json_schema_to_pydantic library requires each oneOf variant to have
|
||||||
|
a const field for the discriminator property. This function adds those
|
||||||
|
const fields based on the discriminator mapping.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: JSON Schema dict that may contain discriminated unions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified schema with const fields added to oneOf variants
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _process_oneof(node: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Process a single node that might contain a oneOf with discriminator."""
|
||||||
|
if not isinstance(node, dict):
|
||||||
|
return node
|
||||||
|
|
||||||
|
if "oneOf" in node and "discriminator" in node:
|
||||||
|
discriminator = node["discriminator"]
|
||||||
|
property_name = discriminator.get("propertyName")
|
||||||
|
mapping = discriminator.get("mapping", {})
|
||||||
|
|
||||||
|
if property_name and mapping:
|
||||||
|
one_of_variants = node.get("oneOf", [])
|
||||||
|
|
||||||
|
for variant in one_of_variants:
|
||||||
|
if isinstance(variant, dict) and "properties" in variant:
|
||||||
|
variant_title = variant.get("title", "")
|
||||||
|
|
||||||
|
matched_disc_value = None
|
||||||
|
for disc_value, schema_name in mapping.items():
|
||||||
|
if variant_title == schema_name or variant_title.endswith(
|
||||||
|
schema_name
|
||||||
|
):
|
||||||
|
matched_disc_value = disc_value
|
||||||
|
break
|
||||||
|
|
||||||
|
if matched_disc_value is not None:
|
||||||
|
props = variant["properties"]
|
||||||
|
if property_name in props:
|
||||||
|
props[property_name]["const"] = matched_disc_value
|
||||||
|
|
||||||
|
for key, value in node.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
node[key] = _process_oneof(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
node[key] = [
|
||||||
|
_process_oneof(item) if isinstance(item, dict) else item
|
||||||
|
for item in value
|
||||||
|
]
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
return _process_oneof(deepcopy(schema))
|
||||||
|
|
||||||
|
|
||||||
|
def convert_oneof_to_anyof(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Convert oneOf to anyOf for OpenAI compatibility.
|
||||||
|
|
||||||
|
OpenAI's Structured Outputs support anyOf better than oneOf.
|
||||||
|
This recursively converts all oneOf occurrences to anyOf.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: JSON schema dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified schema with anyOf instead of oneOf.
|
||||||
|
"""
|
||||||
|
if isinstance(schema, dict):
|
||||||
|
if "oneOf" in schema:
|
||||||
|
schema["anyOf"] = schema.pop("oneOf")
|
||||||
|
|
||||||
|
for value in schema.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
convert_oneof_to_anyof(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
convert_oneof_to_anyof(item)
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_all_properties_required(schema: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Ensure all properties are in the required array for OpenAI strict mode.
|
||||||
|
|
||||||
|
OpenAI's strict structured outputs require all properties to be listed
|
||||||
|
in the required array. This recursively updates all objects to include
|
||||||
|
all their properties in required.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema: JSON schema dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified schema with all properties marked as required.
|
||||||
|
"""
|
||||||
|
if isinstance(schema, dict):
|
||||||
|
if schema.get("type") == "object" and "properties" in schema:
|
||||||
|
properties = schema["properties"]
|
||||||
|
if properties:
|
||||||
|
schema["required"] = list(properties.keys())
|
||||||
|
|
||||||
|
for value in schema.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
ensure_all_properties_required(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
ensure_all_properties_required(item)
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def generate_model_description(model: type[BaseModel]) -> dict[str, Any]:
|
||||||
|
"""Generate JSON schema description of a Pydantic model.
|
||||||
|
|
||||||
|
This function takes a Pydantic model class and returns its JSON schema,
|
||||||
|
which includes full type information, discriminators, and all metadata.
|
||||||
|
The schema is dereferenced to inline all $ref references for better LLM understanding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: A Pydantic model class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON schema dictionary representation of the model.
|
||||||
|
"""
|
||||||
|
json_schema = model.model_json_schema(ref_template="#/$defs/{model}")
|
||||||
|
|
||||||
|
json_schema = add_key_in_dict_recursively(
|
||||||
|
json_schema,
|
||||||
|
key="additionalProperties",
|
||||||
|
value=False,
|
||||||
|
criteria=lambda d: d.get("type") == "object"
|
||||||
|
and "additionalProperties" not in d,
|
||||||
|
)
|
||||||
|
|
||||||
|
json_schema = resolve_refs(json_schema)
|
||||||
|
|
||||||
|
json_schema.pop("$defs", None)
|
||||||
|
json_schema = fix_discriminator_mappings(json_schema)
|
||||||
|
json_schema = convert_oneof_to_anyof(json_schema)
|
||||||
|
json_schema = ensure_all_properties_required(json_schema)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": {
|
||||||
|
"name": model.__name__,
|
||||||
|
"strict": True,
|
||||||
|
"schema": json_schema,
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -1727,3 +1727,24 @@ def test_task_output_includes_messages():
|
|||||||
assert hasattr(task2_output, "messages")
|
assert hasattr(task2_output, "messages")
|
||||||
assert isinstance(task2_output.messages, list)
|
assert isinstance(task2_output.messages, list)
|
||||||
assert len(task2_output.messages) > 0
|
assert len(task2_output.messages) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_execution_fails():
|
||||||
|
researcher = Agent(
|
||||||
|
role="Researcher",
|
||||||
|
goal="Make the best research and analysis on content about AI and AI agents",
|
||||||
|
backstory="You're an expert researcher, specialized in technology, software engineering, AI and startups. You work as a freelancer and is now working on doing research and analysis for a new customer.",
|
||||||
|
allow_delegation=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.",
|
||||||
|
expected_output="Bullet point list of 5 interesting ideas.",
|
||||||
|
async_execution=True,
|
||||||
|
agent=researcher,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(Task, "_execute_core", side_effect=RuntimeError("boom!")):
|
||||||
|
with pytest.raises(RuntimeError, match="boom!"):
|
||||||
|
execution = task.execute_async(agent=researcher)
|
||||||
|
execution.result()
|
||||||
|
|||||||
@@ -17,10 +17,11 @@ def test_creating_a_tool_using_annotation():
|
|||||||
|
|
||||||
# Assert all the right attributes were defined
|
# Assert all the right attributes were defined
|
||||||
assert my_tool.name == "Name of my tool"
|
assert my_tool.name == "Name of my tool"
|
||||||
assert (
|
assert "Tool Name: Name of my tool" in my_tool.description
|
||||||
my_tool.description
|
assert "Tool Arguments:" in my_tool.description
|
||||||
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, your agent will need this information to use it."
|
assert '"question"' in my_tool.description
|
||||||
)
|
assert '"type": "string"' in my_tool.description
|
||||||
|
assert "Tool Description: Clear description for what this tool is useful for" in my_tool.description
|
||||||
assert my_tool.args_schema.model_json_schema()["properties"] == {
|
assert my_tool.args_schema.model_json_schema()["properties"] == {
|
||||||
"question": {"title": "Question", "type": "string"}
|
"question": {"title": "Question", "type": "string"}
|
||||||
}
|
}
|
||||||
@@ -31,10 +32,9 @@ def test_creating_a_tool_using_annotation():
|
|||||||
converted_tool = my_tool.to_structured_tool()
|
converted_tool = my_tool.to_structured_tool()
|
||||||
assert converted_tool.name == "Name of my tool"
|
assert converted_tool.name == "Name of my tool"
|
||||||
|
|
||||||
assert (
|
assert "Tool Name: Name of my tool" in converted_tool.description
|
||||||
converted_tool.description
|
assert "Tool Arguments:" in converted_tool.description
|
||||||
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, your agent will need this information to use it."
|
assert '"question"' in converted_tool.description
|
||||||
)
|
|
||||||
assert converted_tool.args_schema.model_json_schema()["properties"] == {
|
assert converted_tool.args_schema.model_json_schema()["properties"] == {
|
||||||
"question": {"title": "Question", "type": "string"}
|
"question": {"title": "Question", "type": "string"}
|
||||||
}
|
}
|
||||||
@@ -56,10 +56,11 @@ def test_creating_a_tool_using_baseclass():
|
|||||||
# Assert all the right attributes were defined
|
# Assert all the right attributes were defined
|
||||||
assert my_tool.name == "Name of my tool"
|
assert my_tool.name == "Name of my tool"
|
||||||
|
|
||||||
assert (
|
assert "Tool Name: Name of my tool" in my_tool.description
|
||||||
my_tool.description
|
assert "Tool Arguments:" in my_tool.description
|
||||||
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, your agent will need this information to use it."
|
assert '"question"' in my_tool.description
|
||||||
)
|
assert '"type": "string"' in my_tool.description
|
||||||
|
assert "Tool Description: Clear description for what this tool is useful for" in my_tool.description
|
||||||
assert my_tool.args_schema.model_json_schema()["properties"] == {
|
assert my_tool.args_schema.model_json_schema()["properties"] == {
|
||||||
"question": {"title": "Question", "type": "string"}
|
"question": {"title": "Question", "type": "string"}
|
||||||
}
|
}
|
||||||
@@ -68,10 +69,9 @@ def test_creating_a_tool_using_baseclass():
|
|||||||
converted_tool = my_tool.to_structured_tool()
|
converted_tool = my_tool.to_structured_tool()
|
||||||
assert converted_tool.name == "Name of my tool"
|
assert converted_tool.name == "Name of my tool"
|
||||||
|
|
||||||
assert (
|
assert "Tool Name: Name of my tool" in converted_tool.description
|
||||||
converted_tool.description
|
assert "Tool Arguments:" in converted_tool.description
|
||||||
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, your agent will need this information to use it."
|
assert '"question"' in converted_tool.description
|
||||||
)
|
|
||||||
assert converted_tool.args_schema.model_json_schema()["properties"] == {
|
assert converted_tool.args_schema.model_json_schema()["properties"] == {
|
||||||
"question": {"title": "Question", "type": "string"}
|
"question": {"title": "Question", "type": "string"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,25 +107,20 @@ def test_tool_usage_render():
|
|||||||
|
|
||||||
rendered = tool_usage._render()
|
rendered = tool_usage._render()
|
||||||
|
|
||||||
# Updated checks to match the actual output
|
# Check that the rendered output contains the expected tool information
|
||||||
assert "Tool Name: Random Number Generator" in rendered
|
assert "Tool Name: Random Number Generator" in rendered
|
||||||
assert "Tool Arguments:" in rendered
|
assert "Tool Arguments:" in rendered
|
||||||
assert (
|
|
||||||
"'min_value': {'description': 'The minimum value of the range (inclusive)', 'type': 'int'}"
|
|
||||||
in rendered
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
"'max_value': {'description': 'The maximum value of the range (inclusive)', 'type': 'int'}"
|
|
||||||
in rendered
|
|
||||||
)
|
|
||||||
assert (
|
assert (
|
||||||
"Tool Description: Generates a random number within a specified range"
|
"Tool Description: Generates a random number within a specified range"
|
||||||
in rendered
|
in rendered
|
||||||
)
|
)
|
||||||
assert (
|
|
||||||
"Tool Name: Random Number Generator\nTool Arguments: {'min_value': {'description': 'The minimum value of the range (inclusive)', 'type': 'int'}, 'max_value': {'description': 'The maximum value of the range (inclusive)', 'type': 'int'}}\nTool Description: Generates a random number within a specified range"
|
# Check that the JSON schema format is used (proper JSON schema types)
|
||||||
in rendered
|
assert '"min_value"' in rendered
|
||||||
)
|
assert '"max_value"' in rendered
|
||||||
|
assert '"type": "integer"' in rendered
|
||||||
|
assert '"description": "The minimum value of the range (inclusive)"' in rendered
|
||||||
|
assert '"description": "The maximum value of the range (inclusive)"' in rendered
|
||||||
|
|
||||||
|
|
||||||
def test_validate_tool_input_booleans_and_none():
|
def test_validate_tool_input_booleans_and_none():
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from unittest import mock
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from crewai.utilities.converter import ConverterError
|
from crewai.utilities.converter import ConverterError
|
||||||
@@ -44,26 +43,26 @@ def test_evaluate_training_data(converter_mock):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result == function_return_value
|
assert result == function_return_value
|
||||||
converter_mock.assert_has_calls(
|
|
||||||
[
|
# Verify the converter was called with correct arguments
|
||||||
mock.call(
|
converter_mock.assert_called_once()
|
||||||
llm=original_agent.llm,
|
call_kwargs = converter_mock.call_args.kwargs
|
||||||
text="Assess the quality of the training data based on the llm output, human feedback , and llm "
|
|
||||||
"output improved result.\n\nIteration: data1\nInitial Output:\nInitial output 1\n\nHuman Feedback:\nHuman feedback "
|
assert call_kwargs["llm"] == original_agent.llm
|
||||||
"1\n\nImproved Output:\nImproved output 1\n\n------------------------------------------------\n\nIteration: data2\nInitial Output:\nInitial output 2\n\nHuman "
|
assert call_kwargs["model"] == TrainingTaskEvaluation
|
||||||
"Feedback:\nHuman feedback 2\n\nImproved Output:\nImproved output 2\n\n------------------------------------------------\n\nPlease provide:\n- Provide "
|
assert "Iteration: data1" in call_kwargs["text"]
|
||||||
"a list of clear, actionable instructions derived from the Human Feedbacks to enhance the Agent's "
|
assert "Iteration: data2" in call_kwargs["text"]
|
||||||
"performance. Analyze the differences between Initial Outputs and Improved Outputs to generate specific "
|
|
||||||
"action items for future tasks. Ensure all key and specificpoints from the human feedback are "
|
instructions = call_kwargs["instructions"]
|
||||||
"incorporated into these instructions.\n- A score from 0 to 10 evaluating on completion, quality, and "
|
assert "I'm gonna convert this raw text into valid JSON." in instructions
|
||||||
"overall performance from the improved output to the initial output based on the human feedback\n",
|
assert "OpenAPI schema" in instructions
|
||||||
model=TrainingTaskEvaluation,
|
assert '"type": "json_schema"' in instructions
|
||||||
instructions="I'm gonna convert this raw text into valid JSON.\n\nThe json should have the "
|
assert '"name": "TrainingTaskEvaluation"' in instructions
|
||||||
"following structure, with the following keys:\n{\n suggestions: List[str],\n quality: float,\n final_summary: str\n}",
|
assert '"suggestions"' in instructions
|
||||||
),
|
assert '"quality"' in instructions
|
||||||
mock.call().to_pydantic(),
|
assert '"final_summary"' in instructions
|
||||||
]
|
|
||||||
)
|
converter_mock.return_value.to_pydantic.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.utilities.converter.Converter.to_pydantic")
|
@patch("crewai.utilities.converter.Converter.to_pydantic")
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from crewai.utilities.converter import (
|
|||||||
handle_partial_json,
|
handle_partial_json,
|
||||||
validate_model,
|
validate_model,
|
||||||
)
|
)
|
||||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
|
||||||
|
|
||||||
|
|
||||||
def test_simple_model():
|
|
||||||
class SimpleModel(BaseModel):
|
|
||||||
field1: int
|
|
||||||
field2: str
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=SimpleModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
field1: int,
|
|
||||||
field2: str
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_nested_model():
|
|
||||||
class NestedModel(BaseModel):
|
|
||||||
nested_field: int
|
|
||||||
|
|
||||||
class ParentModel(BaseModel):
|
|
||||||
parent_field: str
|
|
||||||
nested: NestedModel
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=ParentModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
parent_field: str,
|
|
||||||
nested: NestedModel
|
|
||||||
{
|
|
||||||
nested_field: int
|
|
||||||
}
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_with_list():
|
|
||||||
class ListModel(BaseModel):
|
|
||||||
list_field: List[int]
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=ListModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
list_field: List[int]
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_with_optional_field():
|
|
||||||
class OptionalModel(BaseModel):
|
|
||||||
optional_field: Optional[str]
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=OptionalModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
optional_field: Optional[str]
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_with_union():
|
|
||||||
class UnionModel(BaseModel):
|
|
||||||
union_field: Union[int, str]
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=UnionModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
union_field: Union[int, str]
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_with_dict():
|
|
||||||
class DictModel(BaseModel):
|
|
||||||
dict_field: Dict[str, int]
|
|
||||||
|
|
||||||
parser = PydanticSchemaParser(model=DictModel)
|
|
||||||
schema = parser.get_schema()
|
|
||||||
|
|
||||||
expected_schema = """{
|
|
||||||
dict_field: Dict[str, int]
|
|
||||||
}"""
|
|
||||||
assert schema.strip() == expected_schema.strip()
|
|
||||||
Reference in New Issue
Block a user