mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Merge pull request #173 from mplachta/vision-tool-improvement
Vision Tool Refactoring and Simplify the code
This commit is contained in:
@@ -1,17 +1,29 @@
|
||||
import base64
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from typing import Type, Optional
|
||||
from pathlib import Path
|
||||
from crewai.tools import BaseTool
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
class ImagePromptSchema(BaseModel):
|
||||
"""Input for Vision Tool."""
|
||||
|
||||
image_path_url: str = "The image path or URL."
|
||||
|
||||
@validator("image_path_url")
|
||||
def validate_image_path_url(cls, v: str) -> str:
|
||||
if v.startswith("http"):
|
||||
return v
|
||||
|
||||
path = Path(v)
|
||||
if not path.exists():
|
||||
raise ValueError(f"Image file does not exist: {v}")
|
||||
|
||||
# Validate supported formats
|
||||
valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
if path.suffix.lower() not in valid_extensions:
|
||||
raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}")
|
||||
|
||||
return v
|
||||
|
||||
class VisionTool(BaseTool):
|
||||
name: str = "Vision Tool"
|
||||
@@ -19,75 +31,55 @@ class VisionTool(BaseTool):
|
||||
"This tool uses OpenAI's Vision API to describe the contents of an image."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ImagePromptSchema
|
||||
_client: Optional[OpenAI] = None
|
||||
|
||||
def _run_web_hosted_images(self, client, image_path_url: str) -> str:
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_path_url},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _run_local_images(self, client, image_path_url: str) -> str:
|
||||
base64_image = self._encode_image(image_path_url)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {client.api_key}",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": 300,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload
|
||||
)
|
||||
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
@property
|
||||
def client(self) -> OpenAI:
|
||||
"""Cached OpenAI client instance."""
|
||||
if self._client is None:
|
||||
self._client = OpenAI()
|
||||
return self._client
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
client = OpenAI()
|
||||
try:
|
||||
image_path_url = kwargs.get("image_path_url")
|
||||
if not image_path_url:
|
||||
return "Image Path or URL is required."
|
||||
|
||||
# Validate input using Pydantic
|
||||
ImagePromptSchema(image_path_url=image_path_url)
|
||||
|
||||
if image_path_url.startswith("http"):
|
||||
image_data = image_path_url
|
||||
else:
|
||||
try:
|
||||
base64_image = self._encode_image(image_path_url)
|
||||
image_data = f"data:image/jpeg;base64,{base64_image}"
|
||||
except Exception as e:
|
||||
return f"Error processing image: {str(e)}"
|
||||
|
||||
image_path_url = kwargs.get("image_path_url")
|
||||
response = self.client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
if not image_path_url:
|
||||
return "Image Path or URL is required."
|
||||
return response.choices[0].message.content
|
||||
|
||||
if "http" in image_path_url:
|
||||
image_description = self._run_web_hosted_images(client, image_path_url)
|
||||
else:
|
||||
image_description = self._run_local_images(client, image_path_url)
|
||||
except Exception as e:
|
||||
return f"An error occurred: {str(e)}"
|
||||
|
||||
return image_description
|
||||
|
||||
def _encode_image(self, image_path: str):
|
||||
def _encode_image(self, image_path: str) -> str:
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
Reference in New Issue
Block a user