mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 07:42:40 +00:00
feat: add file upload cache
This commit is contained in:
296
lib/crewai/src/crewai/utilities/files/upload_cache.py
Normal file
296
lib/crewai/src/crewai/utilities/files/upload_cache.py
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
"""Cache for tracking uploaded files to avoid redundant uploads."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from crewai.utilities.files.content_types import (
|
||||||
|
AudioFile,
|
||||||
|
ImageFile,
|
||||||
|
PDFFile,
|
||||||
|
TextFile,
|
||||||
|
VideoFile,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FileInput = AudioFile | ImageFile | PDFFile | TextFile | VideoFile
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CachedUpload:
|
||||||
|
"""Represents a cached file upload.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
file_id: Provider-specific file identifier.
|
||||||
|
provider: Name of the provider.
|
||||||
|
file_uri: Optional URI for accessing the file.
|
||||||
|
content_type: MIME type of the uploaded file.
|
||||||
|
uploaded_at: When the file was uploaded.
|
||||||
|
expires_at: When the upload expires (if applicable).
|
||||||
|
"""
|
||||||
|
|
||||||
|
file_id: str
|
||||||
|
provider: str
|
||||||
|
file_uri: str | None
|
||||||
|
content_type: str
|
||||||
|
uploaded_at: datetime
|
||||||
|
expires_at: datetime | None = None
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""Check if this cached upload has expired.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if expired, False otherwise.
|
||||||
|
"""
|
||||||
|
if self.expires_at is None:
|
||||||
|
return False
|
||||||
|
return datetime.now(timezone.utc) >= self.expires_at
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UploadCache:
|
||||||
|
"""Thread-safe cache for tracking uploaded files.
|
||||||
|
|
||||||
|
Uses file content hash and provider as composite key to avoid
|
||||||
|
uploading the same file multiple times.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
default_ttl: Default time-to-live for cached entries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_ttl: timedelta = field(default_factory=lambda: timedelta(hours=24))
|
||||||
|
_cache: dict[tuple[str, str], CachedUpload] = field(default_factory=dict)
|
||||||
|
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||||
|
|
||||||
|
def _compute_hash(self, file: FileInput) -> str:
|
||||||
|
"""Compute a hash of file content for cache key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: The file to hash.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SHA-256 hash of the file content.
|
||||||
|
"""
|
||||||
|
content = file.source.read()
|
||||||
|
return hashlib.sha256(content).hexdigest()
|
||||||
|
|
||||||
|
def get(self, file: FileInput, provider: str) -> CachedUpload | None:
|
||||||
|
"""Get a cached upload for a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: The file to look up.
|
||||||
|
provider: The provider name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached upload if found and not expired, None otherwise.
|
||||||
|
"""
|
||||||
|
file_hash = self._compute_hash(file)
|
||||||
|
key = (file_hash, provider)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
cached = self._cache.get(key)
|
||||||
|
if cached is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cached.is_expired():
|
||||||
|
del self._cache[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cached
|
||||||
|
|
||||||
|
def get_by_hash(self, file_hash: str, provider: str) -> CachedUpload | None:
|
||||||
|
"""Get a cached upload by file hash.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_hash: Hash of the file content.
|
||||||
|
provider: The provider name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached upload if found and not expired, None otherwise.
|
||||||
|
"""
|
||||||
|
key = (file_hash, provider)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
cached = self._cache.get(key)
|
||||||
|
if cached is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cached.is_expired():
|
||||||
|
del self._cache[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cached
|
||||||
|
|
||||||
|
def set(
|
||||||
|
self,
|
||||||
|
file: FileInput,
|
||||||
|
provider: str,
|
||||||
|
file_id: str,
|
||||||
|
file_uri: str | None = None,
|
||||||
|
expires_at: datetime | None = None,
|
||||||
|
) -> CachedUpload:
|
||||||
|
"""Cache an uploaded file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: The file that was uploaded.
|
||||||
|
provider: The provider name.
|
||||||
|
file_id: Provider-specific file identifier.
|
||||||
|
file_uri: Optional URI for accessing the file.
|
||||||
|
expires_at: When the upload expires.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created cache entry.
|
||||||
|
"""
|
||||||
|
file_hash = self._compute_hash(file)
|
||||||
|
key = (file_hash, provider)
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
cached = CachedUpload(
|
||||||
|
file_id=file_id,
|
||||||
|
provider=provider,
|
||||||
|
file_uri=file_uri,
|
||||||
|
content_type=file.content_type,
|
||||||
|
uploaded_at=now,
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self._cache[key] = cached
|
||||||
|
|
||||||
|
logger.debug(f"Cached upload: {file_id} for provider {provider}")
|
||||||
|
return cached
|
||||||
|
|
||||||
|
def set_by_hash(
|
||||||
|
self,
|
||||||
|
file_hash: str,
|
||||||
|
content_type: str,
|
||||||
|
provider: str,
|
||||||
|
file_id: str,
|
||||||
|
file_uri: str | None = None,
|
||||||
|
expires_at: datetime | None = None,
|
||||||
|
) -> CachedUpload:
|
||||||
|
"""Cache an uploaded file by hash.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_hash: Hash of the file content.
|
||||||
|
content_type: MIME type of the file.
|
||||||
|
provider: The provider name.
|
||||||
|
file_id: Provider-specific file identifier.
|
||||||
|
file_uri: Optional URI for accessing the file.
|
||||||
|
expires_at: When the upload expires.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created cache entry.
|
||||||
|
"""
|
||||||
|
key = (file_hash, provider)
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
cached = CachedUpload(
|
||||||
|
file_id=file_id,
|
||||||
|
provider=provider,
|
||||||
|
file_uri=file_uri,
|
||||||
|
content_type=content_type,
|
||||||
|
uploaded_at=now,
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self._cache[key] = cached
|
||||||
|
|
||||||
|
logger.debug(f"Cached upload: {file_id} for provider {provider}")
|
||||||
|
return cached
|
||||||
|
|
||||||
|
def remove(self, file: FileInput, provider: str) -> bool:
|
||||||
|
"""Remove a cached upload.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: The file to remove.
|
||||||
|
provider: The provider name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if entry was removed, False if not found.
|
||||||
|
"""
|
||||||
|
file_hash = self._compute_hash(file)
|
||||||
|
key = (file_hash, provider)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
if key in self._cache:
|
||||||
|
del self._cache[key]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def remove_by_file_id(self, file_id: str, provider: str) -> bool:
|
||||||
|
"""Remove a cached upload by file ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The file ID to remove.
|
||||||
|
provider: The provider name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if entry was removed, False if not found.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
for key, cached in list(self._cache.items()):
|
||||||
|
if cached.file_id == file_id and cached.provider == provider:
|
||||||
|
del self._cache[key]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear_expired(self) -> int:
|
||||||
|
"""Remove all expired entries from the cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries removed.
|
||||||
|
"""
|
||||||
|
removed = 0
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
for key in list(self._cache.keys()):
|
||||||
|
if self._cache[key].is_expired():
|
||||||
|
del self._cache[key]
|
||||||
|
removed += 1
|
||||||
|
|
||||||
|
if removed > 0:
|
||||||
|
logger.debug(f"Cleared {removed} expired cache entries")
|
||||||
|
|
||||||
|
return removed
|
||||||
|
|
||||||
|
def clear(self) -> int:
|
||||||
|
"""Clear all entries from the cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries cleared.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
count = len(self._cache)
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
if count > 0:
|
||||||
|
logger.debug(f"Cleared {count} cache entries")
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
def get_all_for_provider(self, provider: str) -> list[CachedUpload]:
|
||||||
|
"""Get all cached uploads for a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: The provider name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of cached uploads for the provider.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
return [
|
||||||
|
cached
|
||||||
|
for (_, p), cached in self._cache.items()
|
||||||
|
if p == provider and not cached.is_expired()
|
||||||
|
]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return the number of cached entries."""
|
||||||
|
with self._lock:
|
||||||
|
return len(self._cache)
|
||||||
Reference in New Issue
Block a user