Merge pull request #173 from mplachta/vision-tool-improvement

Vision Tool Refactoring and Simplify the code
This commit is contained in:
Brandon Hancock (bhancock_ai)
2025-01-03 13:31:23 -05:00
committed by GitHub

View File

@@ -1,17 +1,29 @@
import base64
from typing import Type
import requests
from typing import Type, Optional
from pathlib import Path
from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel
from pydantic import BaseModel, validator
class ImagePromptSchema(BaseModel):
"""Input for Vision Tool."""
image_path_url: str = "The image path or URL."
@validator("image_path_url")
def validate_image_path_url(cls, v: str) -> str:
if v.startswith("http"):
return v
path = Path(v)
if not path.exists():
raise ValueError(f"Image file does not exist: {v}")
# Validate supported formats
valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
if path.suffix.lower() not in valid_extensions:
raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}")
return v
class VisionTool(BaseTool):
name: str = "Vision Tool"
@@ -19,75 +31,55 @@ class VisionTool(BaseTool):
"This tool uses OpenAI's Vision API to describe the contents of an image."
)
args_schema: Type[BaseModel] = ImagePromptSchema
_client: Optional[OpenAI] = None
def _run_web_hosted_images(self, client, image_path_url: str) -> str:
response = client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": image_path_url},
},
],
}
],
max_tokens=300,
)
return response.choices[0].message.content
def _run_local_images(self, client, image_path_url: str) -> str:
base64_image = self._encode_image(image_path_url)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {client.api_key}",
}
payload = {
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
},
},
],
}
],
"max_tokens": 300,
}
response = requests.post(
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload
)
return response.json()["choices"][0]["message"]["content"]
@property
def client(self) -> OpenAI:
"""Cached OpenAI client instance."""
if self._client is None:
self._client = OpenAI()
return self._client
def _run(self, **kwargs) -> str:
client = OpenAI()
try:
image_path_url = kwargs.get("image_path_url")
if not image_path_url:
return "Image Path or URL is required."
# Validate input using Pydantic
ImagePromptSchema(image_path_url=image_path_url)
if image_path_url.startswith("http"):
image_data = image_path_url
else:
try:
base64_image = self._encode_image(image_path_url)
image_data = f"data:image/jpeg;base64,{base64_image}"
except Exception as e:
return f"Error processing image: {str(e)}"
image_path_url = kwargs.get("image_path_url")
response = self.client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": image_data},
}
],
}
],
max_tokens=300,
)
if not image_path_url:
return "Image Path or URL is required."
return response.choices[0].message.content
if "http" in image_path_url:
image_description = self._run_web_hosted_images(client, image_path_url)
else:
image_description = self._run_local_images(client, image_path_url)
except Exception as e:
return f"An error occurred: {str(e)}"
return image_description
def _encode_image(self, image_path: str):
def _encode_image(self, image_path: str) -> str:
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")