mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Allow setting custom LLM for the vision tool (#294)
* Allow setting custom LLM for the vision tool Defaults to gpt-4o-mini otherwise * Enhance VisionTool with model management and improved initialization - Added support for setting a custom model identifier with a default of "gpt-4o-mini". - Introduced properties for model management, allowing dynamic updates and resetting of the LLM instance. - Updated the initialization method to accept an optional LLM and model parameter. - Refactored the image processing logic for clarity and efficiency. * docstrings * Add stop config --------- Co-authored-by: lorenzejay <lorenzejaytech@gmail.com>
This commit is contained in:
@@ -2,9 +2,9 @@ import base64
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from crewai import LLM
|
||||
from crewai.tools import BaseTool
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel, PrivateAttr, field_validator
|
||||
|
||||
|
||||
class ImagePromptSchema(BaseModel):
|
||||
@@ -32,19 +32,52 @@ class ImagePromptSchema(BaseModel):
|
||||
|
||||
|
||||
class VisionTool(BaseTool):
|
||||
"""Tool for analyzing images using vision models.
|
||||
|
||||
Args:
|
||||
llm: Optional LLM instance to use
|
||||
model: Model identifier to use if no LLM is provided
|
||||
"""
|
||||
|
||||
name: str = "Vision Tool"
|
||||
description: str = (
|
||||
"This tool uses OpenAI's Vision API to describe the contents of an image."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ImagePromptSchema
|
||||
_client: Optional[OpenAI] = None
|
||||
|
||||
_model: str = PrivateAttr(default="gpt-4o-mini")
|
||||
_llm: Optional[LLM] = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, llm: Optional[LLM] = None, model: str = "gpt-4o-mini", **kwargs):
|
||||
"""Initialize the vision tool.
|
||||
|
||||
Args:
|
||||
llm: Optional LLM instance to use
|
||||
model: Model identifier to use if no LLM is provided
|
||||
**kwargs: Additional arguments for the base tool
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._model = model
|
||||
self._llm = llm
|
||||
|
||||
@property
|
||||
def client(self) -> OpenAI:
|
||||
"""Cached OpenAI client instance."""
|
||||
if self._client is None:
|
||||
self._client = OpenAI()
|
||||
return self._client
|
||||
def model(self) -> str:
|
||||
"""Get the current model identifier."""
|
||||
return self._model
|
||||
|
||||
@model.setter
|
||||
def model(self, value: str) -> None:
|
||||
"""Set the model identifier and reset LLM if it was auto-created."""
|
||||
self._model = value
|
||||
if self._llm is not None and self._llm._model != value:
|
||||
self._llm = None
|
||||
|
||||
@property
|
||||
def llm(self) -> LLM:
|
||||
"""Get the LLM instance, creating one if needed."""
|
||||
if self._llm is None:
|
||||
self._llm = LLM(model=self._model, stop=["STOP", "END"])
|
||||
return self._llm
|
||||
|
||||
def _run(self, **kwargs) -> str:
|
||||
try:
|
||||
@@ -52,7 +85,6 @@ class VisionTool(BaseTool):
|
||||
if not image_path_url:
|
||||
return "Image Path or URL is required."
|
||||
|
||||
# Validate input using Pydantic
|
||||
ImagePromptSchema(image_path_url=image_path_url)
|
||||
|
||||
if image_path_url.startswith("http"):
|
||||
@@ -64,8 +96,7 @@ class VisionTool(BaseTool):
|
||||
except Exception as e:
|
||||
return f"Error processing image: {str(e)}"
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
response = self.llm.call(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
@@ -76,16 +107,21 @@ class VisionTool(BaseTool):
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
],
|
||||
}
|
||||
},
|
||||
],
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
return f"An error occurred: {str(e)}"
|
||||
|
||||
def _encode_image(self, image_path: str) -> str:
|
||||
"""Encode an image file as base64.
|
||||
|
||||
Args:
|
||||
image_path: Path to the image file
|
||||
|
||||
Returns:
|
||||
Base64-encoded image data
|
||||
"""
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
Reference in New Issue
Block a user