Merge pull request #173 from mplachta/vision-tool-improvement

Vision Tool Refactoring and Simplify the code
This commit is contained in:
Brandon Hancock (bhancock_ai)
2025-01-03 13:31:23 -05:00
committed by GitHub

View File

@@ -1,17 +1,29 @@
import base64 import base64
from typing import Type from typing import Type, Optional
from pathlib import Path
import requests
from crewai.tools import BaseTool from crewai.tools import BaseTool
from openai import OpenAI from openai import OpenAI
from pydantic import BaseModel from pydantic import BaseModel, validator
class ImagePromptSchema(BaseModel): class ImagePromptSchema(BaseModel):
"""Input for Vision Tool.""" """Input for Vision Tool."""
image_path_url: str = "The image path or URL." image_path_url: str = "The image path or URL."
@validator("image_path_url")
def validate_image_path_url(cls, v: str) -> str:
if v.startswith("http"):
return v
path = Path(v)
if not path.exists():
raise ValueError(f"Image file does not exist: {v}")
# Validate supported formats
valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
if path.suffix.lower() not in valid_extensions:
raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}")
return v
class VisionTool(BaseTool): class VisionTool(BaseTool):
name: str = "Vision Tool" name: str = "Vision Tool"
@@ -19,75 +31,55 @@ class VisionTool(BaseTool):
"This tool uses OpenAI's Vision API to describe the contents of an image." "This tool uses OpenAI's Vision API to describe the contents of an image."
) )
args_schema: Type[BaseModel] = ImagePromptSchema args_schema: Type[BaseModel] = ImagePromptSchema
_client: Optional[OpenAI] = None
def _run_web_hosted_images(self, client, image_path_url: str) -> str: @property
response = client.chat.completions.create( def client(self) -> OpenAI:
model="gpt-4o-mini", """Cached OpenAI client instance."""
messages=[ if self._client is None:
{ self._client = OpenAI()
"role": "user", return self._client
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": image_path_url},
},
],
}
],
max_tokens=300,
)
return response.choices[0].message.content
def _run_local_images(self, client, image_path_url: str) -> str:
base64_image = self._encode_image(image_path_url)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {client.api_key}",
}
payload = {
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
},
},
],
}
],
"max_tokens": 300,
}
response = requests.post(
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload
)
return response.json()["choices"][0]["message"]["content"]
def _run(self, **kwargs) -> str: def _run(self, **kwargs) -> str:
client = OpenAI() try:
image_path_url = kwargs.get("image_path_url")
if not image_path_url:
return "Image Path or URL is required."
# Validate input using Pydantic
ImagePromptSchema(image_path_url=image_path_url)
if image_path_url.startswith("http"):
image_data = image_path_url
else:
try:
base64_image = self._encode_image(image_path_url)
image_data = f"data:image/jpeg;base64,{base64_image}"
except Exception as e:
return f"Error processing image: {str(e)}"
image_path_url = kwargs.get("image_path_url") response = self.client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": image_data},
}
],
}
],
max_tokens=300,
)
if not image_path_url: return response.choices[0].message.content
return "Image Path or URL is required."
if "http" in image_path_url: except Exception as e:
image_description = self._run_web_hosted_images(client, image_path_url) return f"An error occurred: {str(e)}"
else:
image_description = self._run_local_images(client, image_path_url)
return image_description def _encode_image(self, image_path: str) -> str:
def _encode_image(self, image_path: str):
with open(image_path, "rb") as image_file: with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8") return base64.b64encode(image_file.read()).decode("utf-8")