Vision Tool Improvement

This commit is contained in:
Mike Plachta
2025-01-03 09:33:59 -08:00
parent 601abb2bc3
commit 66dee007b7

View File

@@ -1,18 +1,31 @@
import base64 import base64
from typing import Type from typing import Type, Optional
from pathlib import Path
import requests from crewai.tools import BaseTool
from openai import OpenAI from openai import OpenAI
from pydantic import BaseModel from pydantic import BaseModel, validator
from crewai_tools.tools.base_tool import BaseTool
class ImagePromptSchema(BaseModel): class ImagePromptSchema(BaseModel):
"""Input for Vision Tool.""" """Input for Vision Tool."""
image_path_url: str = "The image path or URL." image_path_url: str = "The image path or URL."
@validator("image_path_url")
def validate_image_path_url(cls, v: str) -> str:
if v.startswith("http"):
return v
path = Path(v)
if not path.exists():
raise ValueError(f"Image file does not exist: {v}")
# Validate supported formats
valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
if path.suffix.lower() not in valid_extensions:
raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}")
return v
class VisionTool(BaseTool): class VisionTool(BaseTool):
name: str = "Vision Tool" name: str = "Vision Tool"
@@ -20,75 +33,55 @@ class VisionTool(BaseTool):
"This tool uses OpenAI's Vision API to describe the contents of an image." "This tool uses OpenAI's Vision API to describe the contents of an image."
) )
args_schema: Type[BaseModel] = ImagePromptSchema args_schema: Type[BaseModel] = ImagePromptSchema
_client: Optional[OpenAI] = None
def _run_web_hosted_images(self, client, image_path_url: str) -> str: @property
response = client.chat.completions.create( def client(self) -> OpenAI:
model="gpt-4o-mini", """Cached OpenAI client instance."""
messages=[ if self._client is None:
{ self._client = OpenAI()
"role": "user", return self._client
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": image_path_url},
},
],
}
],
max_tokens=300,
)
return response.choices[0].message.content
def _run_local_images(self, client, image_path_url: str) -> str:
base64_image = self._encode_image(image_path_url)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {client.api_key}",
}
payload = {
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
},
},
],
}
],
"max_tokens": 300,
}
response = requests.post(
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload
)
return response.json()["choices"][0]["message"]["content"]
def _run(self, **kwargs) -> str: def _run(self, **kwargs) -> str:
client = OpenAI() try:
image_path_url = kwargs.get("image_path_url")
if not image_path_url:
return "Image Path or URL is required."
# Validate input using Pydantic
ImagePromptSchema(image_path_url=image_path_url)
if image_path_url.startswith("http"):
image_data = image_path_url
else:
try:
base64_image = self._encode_image(image_path_url)
image_data = f"data:image/jpeg;base64,{base64_image}"
except Exception as e:
return f"Error processing image: {str(e)}"
image_path_url = kwargs.get("image_path_url") response = self.client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": image_data},
}
],
}
],
max_tokens=300,
)
if not image_path_url: return response.choices[0].message.content
return "Image Path or URL is required."
if "http" in image_path_url: except Exception as e:
image_description = self._run_web_hosted_images(client, image_path_url) return f"An error occurred: {str(e)}"
else:
image_description = self._run_local_images(client, image_path_url)
return image_description def _encode_image(self, image_path: str) -> str:
def _encode_image(self, image_path: str):
with open(image_path, "rb") as image_file: with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8") return base64.b64encode(image_file.read()).decode("utf-8")