mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
Merge pull request #173 from mplachta/vision-tool-improvement
Vision Tool refactoring and code simplification
This commit is contained in:
@@ -1,17 +1,29 @@
|
|||||||
import base64
|
import base64
|
||||||
from typing import Type
|
from typing import Type, Optional
|
||||||
|
from pathlib import Path
|
||||||
import requests
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, validator
|
||||||
|
|
||||||
|
|
||||||
class ImagePromptSchema(BaseModel):
    """Input for Vision Tool."""

    # NOTE(review): the string default doubles as a human-readable field
    # description for the agent; it is not a real default path.
    image_path_url: str = "The image path or URL."

    # NOTE(review): `validator` is the pydantic v1 API (deprecated in v2,
    # where `field_validator` replaces it) — kept for compatibility.
    @validator("image_path_url")
    def validate_image_path_url(cls, v: str) -> str:
        """Accept any http(s) URL as-is; for local paths, require that the
        file exists and carries a supported image extension.

        Raises:
            ValueError: if the local file is missing or its extension is
                not one of the supported image formats.
        """
        # Remote URLs are not checked here — only local files are validated.
        if v.startswith("http"):
            return v

        path = Path(v)
        if not path.exists():
            raise ValueError(f"Image file does not exist: {v}")

        # Validate supported formats
        valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
        if path.suffix.lower() not in valid_extensions:
            raise ValueError(
                f"Unsupported image format. Supported formats: {valid_extensions}"
            )

        return v
||||||
class VisionTool(BaseTool):
|
class VisionTool(BaseTool):
|
||||||
name: str = "Vision Tool"
|
name: str = "Vision Tool"
|
||||||
@@ -19,9 +31,34 @@ class VisionTool(BaseTool):
|
|||||||
"This tool uses OpenAI's Vision API to describe the contents of an image."
|
"This tool uses OpenAI's Vision API to describe the contents of an image."
|
||||||
)
|
)
|
||||||
args_schema: Type[BaseModel] = ImagePromptSchema
|
args_schema: Type[BaseModel] = ImagePromptSchema
|
||||||
|
_client: Optional[OpenAI] = None
|
||||||
|
|
||||||
def _run_web_hosted_images(self, client, image_path_url: str) -> str:
|
@property
|
||||||
response = client.chat.completions.create(
|
def client(self) -> OpenAI:
|
||||||
|
"""Cached OpenAI client instance."""
|
||||||
|
if self._client is None:
|
||||||
|
self._client = OpenAI()
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _run(self, **kwargs) -> str:
|
||||||
|
try:
|
||||||
|
image_path_url = kwargs.get("image_path_url")
|
||||||
|
if not image_path_url:
|
||||||
|
return "Image Path or URL is required."
|
||||||
|
|
||||||
|
# Validate input using Pydantic
|
||||||
|
ImagePromptSchema(image_path_url=image_path_url)
|
||||||
|
|
||||||
|
if image_path_url.startswith("http"):
|
||||||
|
image_data = image_path_url
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
base64_image = self._encode_image(image_path_url)
|
||||||
|
image_data = f"data:image/jpeg;base64,{base64_image}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error processing image: {str(e)}"
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
model="gpt-4o-mini",
|
model="gpt-4o-mini",
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
@@ -30,8 +67,8 @@ class VisionTool(BaseTool):
|
|||||||
{"type": "text", "text": "What's in this image?"},
|
{"type": "text", "text": "What's in this image?"},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {"url": image_path_url},
|
"image_url": {"url": image_data},
|
||||||
},
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -40,54 +77,9 @@ class VisionTool(BaseTool):
|
|||||||
|
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|
||||||
def _run_local_images(self, client, image_path_url: str) -> str:
|
except Exception as e:
|
||||||
base64_image = self._encode_image(image_path_url)
|
return f"An error occurred: {str(e)}"
|
||||||
|
|
||||||
headers = {
|
def _encode_image(self, image_path: str) -> str:
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {client.api_key}",
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": "gpt-4o-mini",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": "What's in this image?"},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"max_tokens": 300,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
"https://api.openai.com/v1/chat/completions", headers=headers, json=payload
|
|
||||||
)
|
|
||||||
|
|
||||||
return response.json()["choices"][0]["message"]["content"]
|
|
||||||
|
|
||||||
def _run(self, **kwargs) -> str:
|
|
||||||
client = OpenAI()
|
|
||||||
|
|
||||||
image_path_url = kwargs.get("image_path_url")
|
|
||||||
|
|
||||||
if not image_path_url:
|
|
||||||
return "Image Path or URL is required."
|
|
||||||
|
|
||||||
if "http" in image_path_url:
|
|
||||||
image_description = self._run_web_hosted_images(client, image_path_url)
|
|
||||||
else:
|
|
||||||
image_description = self._run_local_images(client, image_path_url)
|
|
||||||
|
|
||||||
return image_description
|
|
||||||
|
|
||||||
def _encode_image(self, image_path: str):
|
|
||||||
with open(image_path, "rb") as image_file:
|
with open(image_path, "rb") as image_file:
|
||||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
|||||||
Reference in New Issue
Block a user