mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
git-subtree-dir: packages/tools git-subtree-split: 78317b9c127f18bd040c1d77e3c0840cdc9a5b38
53 lines
1.3 KiB
Python
53 lines
1.3 KiB
Python
import json
|
|
from typing import List, Type
|
|
|
|
from crewai.tools import BaseTool, EnvVar
|
|
from openai import OpenAI
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class ImagePromptSchema(BaseModel):
    """Input for Dall-E Tool."""

    # Free-text prompt handed to the image model. The Field has no default,
    # so pydantic treats it as a required argument of the tool.
    image_description: str = Field(
        description="Description of the image to be generated by Dall-E.",
    )
|
|
|
|
|
|
class DallETool(BaseTool):
    # CrewAI tool that turns a text description into an image via the OpenAI
    # Images API. Documented with `#` comments rather than a class docstring
    # because pydantic copies a model's docstring into its JSON-schema
    # description.
    name: str = "Dall-E Tool"
    description: str = "Generates images using OpenAI's Dall-E model."
    args_schema: Type[BaseModel] = ImagePromptSchema

    # Generation parameters, forwarded unchanged to `client.images.generate`.
    model: str = "dall-e-3"
    size: str = "1024x1024"
    quality: str = "standard"
    # NOTE: `n` images are requested, but only the first one is reported back
    # by `_run`.
    n: int = 1

    env_vars: List[EnvVar] = [
        EnvVar(name="OPENAI_API_KEY", description="API key for OpenAI services", required=True),
    ]

    def _run(self, **kwargs) -> str:
        """Generate an image for ``kwargs["image_description"]``.

        Args:
            **kwargs: Tool arguments; only ``image_description`` (the prompt
                sent to the model) is read.

        Returns:
            A JSON string ``{"image_url": ..., "image_description": ...}``
            describing the first generated image. ``image_description`` is the
            prompt as revised by the API, or the submitted prompt when the API
            reports no revision. If the description is missing/empty, or the
            API returns no images, a plain error-message string is returned
            instead of JSON.
        """
        image_description = kwargs.get("image_description")
        if not image_description:
            return "Image description is required."

        # Construct the client only once the input is known to be usable, so a
        # bad call does not depend on credentials (OPENAI_API_KEY, declared in
        # env_vars) being configured.
        client = OpenAI()

        response = client.images.generate(
            model=self.model,
            prompt=image_description,
            size=self.size,
            quality=self.quality,
            n=self.n,
        )

        # `data` may be empty (or None in some SDK versions); report that
        # instead of failing with IndexError/TypeError.
        if not response.data:
            return "No image was generated."

        first_image = response.data[0]
        return json.dumps(
            {
                "image_url": first_image.url,
                # `revised_prompt` is only populated when the API rewrote the
                # prompt; otherwise it is None, so fall back to what was sent.
                "image_description": first_image.revised_prompt or image_description,
            }
        )
|