From 66dee007b76da208a12d57806daba6d9ddf20c69 Mon Sep 17 00:00:00 2001 From: Mike Plachta Date: Fri, 3 Jan 2025 09:33:59 -0800 Subject: [PATCH] Vision Tool Improvement --- .../tools/vision_tool/vision_tool.py | 131 +++++++++--------- 1 file changed, 62 insertions(+), 69 deletions(-) diff --git a/src/crewai_tools/tools/vision_tool/vision_tool.py b/src/crewai_tools/tools/vision_tool/vision_tool.py index 6b7a21dbd..3479cbd74 100644 --- a/src/crewai_tools/tools/vision_tool/vision_tool.py +++ b/src/crewai_tools/tools/vision_tool/vision_tool.py @@ -1,18 +1,31 @@ import base64 -from typing import Type +from typing import Type, Optional +from pathlib import Path -import requests +from crewai.tools import BaseTool from openai import OpenAI -from pydantic import BaseModel - -from crewai_tools.tools.base_tool import BaseTool +from pydantic import BaseModel, validator class ImagePromptSchema(BaseModel): """Input for Vision Tool.""" - image_path_url: str = "The image path or URL." + @validator("image_path_url") + def validate_image_path_url(cls, v: str) -> str: + if v.startswith("http"): + return v + + path = Path(v) + if not path.exists(): + raise ValueError(f"Image file does not exist: {v}") + + # Validate supported formats + valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"} + if path.suffix.lower() not in valid_extensions: + raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}") + + return v class VisionTool(BaseTool): name: str = "Vision Tool" @@ -20,75 +33,55 @@ class VisionTool(BaseTool): "This tool uses OpenAI's Vision API to describe the contents of an image." ) args_schema: Type[BaseModel] = ImagePromptSchema + _client: Optional[OpenAI] = None - def _run_web_hosted_images(self, client, image_path_url: str) -> str: - response = client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": {"url": image_path_url}, - }, - ], - } - ], - max_tokens=300, - ) - - return response.choices[0].message.content - - def _run_local_images(self, client, image_path_url: str) -> str: - base64_image = self._encode_image(image_path_url) - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {client.api_key}", - } - - payload = { - "model": "gpt-4o-mini", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" - }, - }, - ], - } - ], - "max_tokens": 300, - } - - response = requests.post( - "https://api.openai.com/v1/chat/completions", headers=headers, json=payload - ) - - return response.json()["choices"][0]["message"]["content"] + @property + def client(self) -> OpenAI: + """Cached OpenAI client instance.""" + if self._client is None: + self._client = OpenAI() + return self._client def _run(self, **kwargs) -> str: - client = OpenAI() + try: + image_path_url = kwargs.get("image_path_url") + if not image_path_url: + return "Image Path or URL is required." + + # Validate input using Pydantic + ImagePromptSchema(image_path_url=image_path_url) + + if image_path_url.startswith("http"): + image_data = image_path_url + else: + try: + base64_image = self._encode_image(image_path_url) + image_data = f"data:image/jpeg;base64,{base64_image}" + except Exception as e: + return f"Error processing image: {str(e)}" - image_path_url = kwargs.get("image_path_url") + response = self.client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": image_data}, + } + ], + } + ], + max_tokens=300, + ) - if not image_path_url: - return "Image Path or URL is required." + return response.choices[0].message.content - if "http" in image_path_url: - image_description = self._run_web_hosted_images(client, image_path_url) - else: - image_description = self._run_local_images(client, image_path_url) + except Exception as e: + return f"An error occurred: {str(e)}" - return image_description - - def _encode_image(self, image_path: str): + def _encode_image(self, image_path: str) -> str: with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8")