Enhance image handling and analysis capabilities across platforms
- Updated the vision tool to accept both HTTP/HTTPS URLs and local file paths for image analysis. - Implemented caching of user-uploaded images in local directories to ensure reliable access for the vision tool, addressing issues with ephemeral URLs. - Enhanced platform adapters (Discord, Telegram, WhatsApp) to download and cache images, allowing for immediate analysis and enriched message context. - Added a new method to auto-analyze images attached by users, enriching the conversation with detailed descriptions. - Improved documentation for image handling processes and updated related functions for clarity and efficiency.
This commit is contained in:
parent
eb49936a60
commit
5404a8fcd8
7 changed files with 303 additions and 35 deletions
|
|
@ -6,10 +6,13 @@ and implement the required methods.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
|
||||
from enum import Enum
|
||||
|
||||
|
|
@ -20,6 +23,91 @@ from gateway.config import Platform, PlatformConfig
|
|||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image cache utilities
|
||||
#
|
||||
# When users send images on messaging platforms, we download them to a local
|
||||
# cache directory so they can be analyzed by the vision tool (which accepts
|
||||
# local file paths). This avoids issues with ephemeral platform URLs
|
||||
# (e.g. Telegram file URLs expire after ~1 hour).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Default location: ~/.hermes/image_cache/
|
||||
IMAGE_CACHE_DIR = Path(os.path.expanduser("~/.hermes/image_cache"))
|
||||
|
||||
|
||||
def get_image_cache_dir() -> Path:
|
||||
"""Return the image cache directory, creating it if it doesn't exist."""
|
||||
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return IMAGE_CACHE_DIR
|
||||
|
||||
|
||||
def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
|
||||
"""
|
||||
Save raw image bytes to the cache and return the absolute file path.
|
||||
|
||||
Args:
|
||||
data: Raw image bytes.
|
||||
ext: File extension including the dot (e.g. ".jpg", ".png").
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
"""
|
||||
cache_dir = get_image_cache_dir()
|
||||
filename = f"img_{uuid.uuid4().hex[:12]}{ext}"
|
||||
filepath = cache_dir / filename
|
||||
filepath.write_bytes(data)
|
||||
return str(filepath)
|
||||
|
||||
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg") -> str:
|
||||
"""
|
||||
Download an image from a URL and save it to the local cache.
|
||||
|
||||
Uses httpx for async download with a reasonable timeout.
|
||||
|
||||
Args:
|
||||
url: The HTTP/HTTPS URL to download from.
|
||||
ext: File extension including the dot (e.g. ".jpg", ".png").
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
|
||||
|
||||
def cleanup_image_cache(max_age_hours: int = 24) -> int:
|
||||
"""
|
||||
Delete cached images older than *max_age_hours*.
|
||||
|
||||
Returns the number of files removed.
|
||||
"""
|
||||
import time
|
||||
|
||||
cache_dir = get_image_cache_dir()
|
||||
cutoff = time.time() - (max_age_hours * 3600)
|
||||
removed = 0
|
||||
for f in cache_dir.iterdir():
|
||||
if f.is_file() and f.stat().st_mtime < cutoff:
|
||||
try:
|
||||
f.unlink()
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
return removed
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Types of incoming messages."""
|
||||
TEXT = "text"
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from gateway.platforms.base import (
|
|||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -402,9 +403,31 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
# Build media URLs
|
||||
media_urls = [att.url for att in message.attachments]
|
||||
media_types = [att.content_type or "unknown" for att in message.attachments]
|
||||
# Build media URLs -- download image attachments to local cache so the
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
media_types = []
|
||||
for att in message.attachments:
|
||||
content_type = att.content_type or "unknown"
|
||||
if content_type.startswith("image/"):
|
||||
try:
|
||||
# Determine extension from content type (image/png -> .png)
|
||||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||||
if ext not in (".jpg", ".jpeg", ".png", ".gif", ".webp"):
|
||||
ext = ".jpg"
|
||||
cached_path = await cache_image_from_url(att.url, ext=ext)
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(content_type)
|
||||
print(f"[Discord] Cached user image: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Discord] Failed to cache image attachment: {e}", flush=True)
|
||||
# Fall back to the CDN URL if caching fails
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
else:
|
||||
# Non-image attachments: keep the original URL
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ from gateway.platforms.base import (
|
|||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -303,7 +304,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
await self.handle_message(event)
|
||||
|
||||
async def _handle_media_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming media messages."""
|
||||
"""Handle incoming media messages, downloading images to local cache."""
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
|
|
@ -327,6 +328,30 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
if msg.caption:
|
||||
event.text = msg.caption
|
||||
|
||||
# Download photo to local image cache so the vision tool can access it
|
||||
# even after Telegram's ephemeral file URLs expire (~1 hour).
|
||||
if msg.photo:
|
||||
try:
|
||||
# msg.photo is a list of PhotoSize sorted by size; take the largest
|
||||
photo = msg.photo[-1]
|
||||
file_obj = await photo.get_file()
|
||||
# Download the image bytes directly into memory
|
||||
image_bytes = await file_obj.download_as_bytearray()
|
||||
# Determine extension from the file path if available
|
||||
ext = ".jpg"
|
||||
if file_obj.file_path:
|
||||
for candidate in [".png", ".webp", ".gif", ".jpeg", ".jpg"]:
|
||||
if file_obj.file_path.lower().endswith(candidate):
|
||||
ext = candidate
|
||||
break
|
||||
# Save to cache and populate media_urls with the local path
|
||||
cached_path = cache_image_from_bytes(bytes(image_bytes), ext=ext)
|
||||
event.media_urls = [cached_path]
|
||||
event.media_types = [f"image/{ext.lstrip('.')}"]
|
||||
print(f"[Telegram] Cached user photo: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[Telegram] Failed to cache photo: {e}", flush=True)
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from gateway.platforms.base import (
|
|||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -267,7 +268,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = self._build_message_event(msg_data)
|
||||
event = await self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
except asyncio.CancelledError:
|
||||
|
|
@ -278,8 +279,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
|
||||
await asyncio.sleep(1) # Poll interval
|
||||
|
||||
def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
|
||||
"""Build a MessageEvent from bridge message data."""
|
||||
async def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
|
||||
"""Build a MessageEvent from bridge message data, downloading images to cache."""
|
||||
try:
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
|
|
@ -307,13 +308,34 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
user_name=data.get("senderName"),
|
||||
)
|
||||
|
||||
# Download image media URLs to the local cache so the vision tool
|
||||
# can access them reliably regardless of URL expiration.
|
||||
raw_urls = data.get("mediaUrls", [])
|
||||
cached_urls = []
|
||||
media_types = []
|
||||
for url in raw_urls:
|
||||
if msg_type == MessageType.PHOTO and url.startswith(("http://", "https://")):
|
||||
try:
|
||||
cached_path = await cache_image_from_url(url, ext=".jpg")
|
||||
cached_urls.append(cached_path)
|
||||
media_types.append("image/jpeg")
|
||||
print(f"[{self.name}] Cached user image: {cached_path}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to cache image: {e}", flush=True)
|
||||
cached_urls.append(url) # Fall back to original URL
|
||||
media_types.append("image/jpeg")
|
||||
else:
|
||||
cached_urls.append(url)
|
||||
media_types.append("unknown")
|
||||
|
||||
return MessageEvent(
|
||||
text=data.get("body", ""),
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=data,
|
||||
message_id=data.get("messageId"),
|
||||
media_urls=data.get("mediaUrls", []),
|
||||
media_urls=cached_urls,
|
||||
media_types=media_types,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error building event: {e}")
|
||||
|
|
|
|||
106
gateway/run.py
106
gateway/run.py
|
|
@ -58,7 +58,7 @@ from gateway.session import (
|
|||
build_session_context_prompt,
|
||||
)
|
||||
from gateway.delivery import DeliveryRouter, DeliveryTarget
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
|
||||
|
||||
|
||||
class GatewayRunner:
|
||||
|
|
@ -298,10 +298,39 @@ class GatewayRunner:
|
|||
# Load conversation history from transcript
|
||||
history = self.session_store.load_transcript(session_entry.session_id)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Auto-analyze images sent by the user
|
||||
#
|
||||
# If the user attached image(s), we run the vision tool eagerly so
|
||||
# the conversation model always receives a text description. The
|
||||
# local file path is also included so the model can re-examine the
|
||||
# image later with a more targeted question via vision_analyze.
|
||||
#
|
||||
# We filter to image paths only (by media_type) so that non-image
|
||||
# attachments (documents, audio, etc.) are not sent to the vision
|
||||
# tool even when they appear in the same message.
|
||||
# -----------------------------------------------------------------
|
||||
message_text = event.text or ""
|
||||
if event.media_urls:
|
||||
image_paths = []
|
||||
for i, path in enumerate(event.media_urls):
|
||||
# Check media_types if available; otherwise infer from message type
|
||||
mtype = event.media_types[i] if i < len(event.media_types) else ""
|
||||
is_image = (
|
||||
mtype.startswith("image/")
|
||||
or event.message_type == MessageType.PHOTO
|
||||
)
|
||||
if is_image:
|
||||
image_paths.append(path)
|
||||
if image_paths:
|
||||
message_text = await self._enrich_message_with_vision(
|
||||
message_text, image_paths
|
||||
)
|
||||
|
||||
try:
|
||||
# Run the agent
|
||||
response = await self._run_agent(
|
||||
message=event.text,
|
||||
message=message_text,
|
||||
context_prompt=context_prompt,
|
||||
history=history,
|
||||
source=source,
|
||||
|
|
@ -320,10 +349,10 @@ class GatewayRunner:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
# Append to transcript
|
||||
# Append to transcript (use the enriched message so vision context is preserved)
|
||||
self.session_store.append_to_transcript(
|
||||
session_entry.session_id,
|
||||
{"role": "user", "content": event.text, "timestamp": datetime.now().isoformat()}
|
||||
{"role": "user", "content": message_text, "timestamp": datetime.now().isoformat()}
|
||||
)
|
||||
self.session_store.append_to_transcript(
|
||||
session_entry.session_id,
|
||||
|
|
@ -411,6 +440,75 @@ class GatewayRunner:
|
|||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
async def _enrich_message_with_vision(
|
||||
self,
|
||||
user_text: str,
|
||||
image_paths: List[str],
|
||||
) -> str:
|
||||
"""
|
||||
Auto-analyze user-attached images with the vision tool and prepend
|
||||
the descriptions to the message text.
|
||||
|
||||
Each image is analyzed with a general-purpose prompt. The resulting
|
||||
description *and* the local cache path are injected so the model can:
|
||||
1. Immediately understand what the user sent (no extra tool call).
|
||||
2. Re-examine the image with vision_analyze if it needs more detail.
|
||||
|
||||
Args:
|
||||
user_text: The user's original caption / message text.
|
||||
image_paths: List of local file paths to cached images.
|
||||
|
||||
Returns:
|
||||
The enriched message string with vision descriptions prepended.
|
||||
"""
|
||||
from tools.vision_tools import vision_analyze_tool
|
||||
import json as _json
|
||||
|
||||
analysis_prompt = (
|
||||
"Describe everything visible in this image in thorough detail. "
|
||||
"Include any text, code, data, objects, people, layout, colors, "
|
||||
"and any other notable visual information."
|
||||
)
|
||||
|
||||
enriched_parts = []
|
||||
for path in image_paths:
|
||||
try:
|
||||
print(f"[gateway] Auto-analyzing user image: {path}", flush=True)
|
||||
result_json = await vision_analyze_tool(
|
||||
image_url=path,
|
||||
user_prompt=analysis_prompt,
|
||||
)
|
||||
result = _json.loads(result_json)
|
||||
if result.get("success"):
|
||||
description = result.get("analysis", "")
|
||||
enriched_parts.append(
|
||||
f"[User sent an image. Vision analysis:\n{description}]\n"
|
||||
f"[To examine this image further, use vision_analyze with "
|
||||
f"image_url: {path}]"
|
||||
)
|
||||
else:
|
||||
# Analysis failed -- still tell the model the image exists
|
||||
enriched_parts.append(
|
||||
f"[User sent an image but automatic analysis failed. "
|
||||
f"You can try analyzing it with vision_analyze using "
|
||||
f"image_url: {path}]"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Vision auto-analysis error: {e}", flush=True)
|
||||
enriched_parts.append(
|
||||
f"[User sent an image but automatic analysis encountered an error. "
|
||||
f"You can try analyzing it with vision_analyze using "
|
||||
f"image_url: {path}]"
|
||||
)
|
||||
|
||||
# Combine: vision descriptions first, then the user's original text
|
||||
if enriched_parts:
|
||||
prefix = "\n\n".join(enriched_parts)
|
||||
if user_text:
|
||||
return f"{prefix}\n\n{user_text}"
|
||||
return prefix
|
||||
return user_text
|
||||
|
||||
async def _run_agent(
|
||||
self,
|
||||
message: str,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue