Merge pull request #173 from mplachta/vision-tool-improvement

Vision Tool Refactoring and Simplify the code
This commit is contained in:
Brandon Hancock (bhancock_ai)
2025-01-03 13:31:23 -05:00
committed by GitHub

View File

@@ -1,17 +1,29 @@
import base64 import base64
from typing import Type from typing import Type, Optional
from pathlib import Path
import requests
from crewai.tools import BaseTool from crewai.tools import BaseTool
from openai import OpenAI from openai import OpenAI
from pydantic import BaseModel from pydantic import BaseModel, validator
class ImagePromptSchema(BaseModel): class ImagePromptSchema(BaseModel):
"""Input for Vision Tool.""" """Input for Vision Tool."""
image_path_url: str = "The image path or URL." image_path_url: str = "The image path or URL."
@validator("image_path_url")
def validate_image_path_url(cls, v: str) -> str:
if v.startswith("http"):
return v
path = Path(v)
if not path.exists():
raise ValueError(f"Image file does not exist: {v}")
# Validate supported formats
valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
if path.suffix.lower() not in valid_extensions:
raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}")
return v
class VisionTool(BaseTool): class VisionTool(BaseTool):
name: str = "Vision Tool" name: str = "Vision Tool"
@@ -19,9 +31,34 @@ class VisionTool(BaseTool):
"This tool uses OpenAI's Vision API to describe the contents of an image." "This tool uses OpenAI's Vision API to describe the contents of an image."
) )
args_schema: Type[BaseModel] = ImagePromptSchema args_schema: Type[BaseModel] = ImagePromptSchema
_client: Optional[OpenAI] = None
def _run_web_hosted_images(self, client, image_path_url: str) -> str: @property
response = client.chat.completions.create( def client(self) -> OpenAI:
"""Cached OpenAI client instance."""
if self._client is None:
self._client = OpenAI()
return self._client
def _run(self, **kwargs) -> str:
try:
image_path_url = kwargs.get("image_path_url")
if not image_path_url:
return "Image Path or URL is required."
# Validate input using Pydantic
ImagePromptSchema(image_path_url=image_path_url)
if image_path_url.startswith("http"):
image_data = image_path_url
else:
try:
base64_image = self._encode_image(image_path_url)
image_data = f"data:image/jpeg;base64,{base64_image}"
except Exception as e:
return f"Error processing image: {str(e)}"
response = self.client.chat.completions.create(
model="gpt-4o-mini", model="gpt-4o-mini",
messages=[ messages=[
{ {
@@ -30,8 +67,8 @@ class VisionTool(BaseTool):
{"type": "text", "text": "What's in this image?"}, {"type": "text", "text": "What's in this image?"},
{ {
"type": "image_url", "type": "image_url",
"image_url": {"url": image_path_url}, "image_url": {"url": image_data},
}, }
], ],
} }
], ],
@@ -40,54 +77,9 @@ class VisionTool(BaseTool):
return response.choices[0].message.content return response.choices[0].message.content
def _run_local_images(self, client, image_path_url: str) -> str: except Exception as e:
base64_image = self._encode_image(image_path_url) return f"An error occurred: {str(e)}"
headers = { def _encode_image(self, image_path: str) -> str:
"Content-Type": "application/json",
"Authorization": f"Bearer {client.api_key}",
}
payload = {
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
},
},
],
}
],
"max_tokens": 300,
}
response = requests.post(
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload
)
return response.json()["choices"][0]["message"]["content"]
def _run(self, **kwargs) -> str:
client = OpenAI()
image_path_url = kwargs.get("image_path_url")
if not image_path_url:
return "Image Path or URL is required."
if "http" in image_path_url:
image_description = self._run_web_hosted_images(client, image_path_url)
else:
image_description = self._run_local_images(client, image_path_url)
return image_description
def _encode_image(self, image_path: str):
with open(image_path, "rb") as image_file: with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8") return base64.b64encode(image_file.read()).decode("utf-8")