Merge PR #184: feat: Home Assistant integration (REST tools + WebSocket gateway)

Authored by 0xbyt4. Adds smart home control via REST tools (ha_list_entities,
ha_get_state, ha_call_service) with domain blocklist and entity_id validation,
plus WebSocket gateway adapter for real-time event monitoring.

Also includes Gemini 3 thought_signature preservation fix (extra_content on
tool calls) needed for multi-turn tool calling via OpenRouter.
This commit is contained in:
teknium1 2026-03-03 05:01:39 -08:00
commit db0521ce0e
15 changed files with 2494 additions and 7 deletions

View file

@ -26,6 +26,7 @@ class Platform(Enum):
DISCORD = "discord"
WHATSAPP = "whatsapp"
SLACK = "slack"
HOMEASSISTANT = "homeassistant"
@dataclass
@ -378,6 +379,17 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""),
)
# Home Assistant
hass_token = os.getenv("HASS_TOKEN")
if hass_token:
if Platform.HOMEASSISTANT not in config.platforms:
config.platforms[Platform.HOMEASSISTANT] = PlatformConfig()
config.platforms[Platform.HOMEASSISTANT].enabled = True
config.platforms[Platform.HOMEASSISTANT].token = hass_token
hass_url = os.getenv("HASS_URL")
if hass_url:
config.platforms[Platform.HOMEASSISTANT].extra["url"] = hass_url
# Session settings
idle_minutes = os.getenv("SESSION_IDLE_MINUTES")
if idle_minutes:

View file

@ -0,0 +1,416 @@
"""
Home Assistant platform adapter.
Connects to the HA WebSocket API for real-time event monitoring.
State-change events are converted to MessageEvent objects and forwarded
to the agent for processing. Outbound messages are delivered as HA
persistent notifications.
Requires:
- aiohttp (already in messaging extras)
- HASS_TOKEN env var (Long-Lived Access Token)
- HASS_URL env var (default: http://homeassistant.local:8123)
"""
import asyncio
import json
import logging
import os
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Set
try:
import aiohttp
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
aiohttp = None # type: ignore[assignment]
import sys
from pathlib import Path as _Path
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
MessageType,
SendResult,
)
logger = logging.getLogger(__name__)
def check_ha_requirements() -> bool:
"""Check if Home Assistant dependencies are available and configured."""
if not AIOHTTP_AVAILABLE:
return False
if not os.getenv("HASS_TOKEN"):
return False
return True
class HomeAssistantAdapter(BasePlatformAdapter):
"""
Home Assistant WebSocket adapter.
Subscribes to ``state_changed`` events and forwards them as
MessageEvent objects. Supports domain/entity filtering and
per-entity cooldowns to avoid event floods.
"""
MAX_MESSAGE_LENGTH = 4096
# Reconnection backoff schedule (seconds)
_BACKOFF_STEPS = [5, 10, 30, 60]
def __init__(self, config: PlatformConfig):
super().__init__(config, Platform.HOMEASSISTANT)
# Connection state
self._session: Optional["aiohttp.ClientSession"] = None
self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None
self._listen_task: Optional[asyncio.Task] = None
self._msg_id: int = 0
# Configuration from extra
extra = config.extra or {}
token = config.token or os.getenv("HASS_TOKEN", "")
url = extra.get("url") or os.getenv("HASS_URL", "http://homeassistant.local:8123")
self._hass_url: str = url.rstrip("/")
self._hass_token: str = token
# Event filtering
self._watch_domains: Set[str] = set(extra.get("watch_domains", []))
self._watch_entities: Set[str] = set(extra.get("watch_entities", []))
self._ignore_entities: Set[str] = set(extra.get("ignore_entities", []))
self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30))
# Cooldown tracking: entity_id -> last_event_timestamp
self._last_event_time: Dict[str, float] = {}
def _next_id(self) -> int:
"""Return the next WebSocket message ID."""
self._msg_id += 1
return self._msg_id
# ------------------------------------------------------------------
# Connection lifecycle
# ------------------------------------------------------------------
async def connect(self) -> bool:
"""Connect to HA WebSocket API and subscribe to events."""
if not AIOHTTP_AVAILABLE:
print(f"[{self.name}] aiohttp not installed. Run: pip install aiohttp")
return False
if not self._hass_token:
print(f"[{self.name}] No HASS_TOKEN configured")
return False
try:
success = await self._ws_connect()
if not success:
return False
# Start background listener
self._listen_task = asyncio.create_task(self._listen_loop())
self._running = True
print(f"[{self.name}] Connected to {self._hass_url}")
return True
except Exception as e:
print(f"[{self.name}] Failed to connect: {e}")
return False
async def _ws_connect(self) -> bool:
"""Establish WebSocket connection and authenticate."""
ws_url = self._hass_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/api/websocket"
self._session = aiohttp.ClientSession()
self._ws = await self._session.ws_connect(ws_url, heartbeat=30)
# Step 1: Receive auth_required
msg = await self._ws.receive_json()
if msg.get("type") != "auth_required":
logger.error("Expected auth_required, got: %s", msg.get("type"))
await self._cleanup_ws()
return False
# Step 2: Send auth
await self._ws.send_json({
"type": "auth",
"access_token": self._hass_token,
})
# Step 3: Wait for auth_ok
msg = await self._ws.receive_json()
if msg.get("type") != "auth_ok":
logger.error("Auth failed: %s", msg)
await self._cleanup_ws()
return False
# Step 4: Subscribe to state_changed events
sub_id = self._next_id()
await self._ws.send_json({
"id": sub_id,
"type": "subscribe_events",
"event_type": "state_changed",
})
# Verify subscription acknowledgement
msg = await self._ws.receive_json()
if not msg.get("success"):
logger.error("Failed to subscribe to events: %s", msg)
await self._cleanup_ws()
return False
return True
async def _cleanup_ws(self) -> None:
"""Close WebSocket and session."""
if self._ws and not self._ws.closed:
await self._ws.close()
self._ws = None
if self._session and not self._session.closed:
await self._session.close()
self._session = None
async def disconnect(self) -> None:
"""Disconnect from Home Assistant."""
self._running = False
if self._listen_task:
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
pass
self._listen_task = None
await self._cleanup_ws()
print(f"[{self.name}] Disconnected")
# ------------------------------------------------------------------
# Event listener
# ------------------------------------------------------------------
async def _listen_loop(self) -> None:
"""Main event loop with automatic reconnection."""
backoff_idx = 0
while self._running:
try:
await self._read_events()
except asyncio.CancelledError:
return
except Exception as e:
logger.warning("[%s] WebSocket error: %s", self.name, e)
if not self._running:
return
# Reconnect with backoff
delay = self._BACKOFF_STEPS[min(backoff_idx, len(self._BACKOFF_STEPS) - 1)]
print(f"[{self.name}] Reconnecting in {delay}s...")
await asyncio.sleep(delay)
backoff_idx += 1
try:
await self._cleanup_ws()
success = await self._ws_connect()
if success:
backoff_idx = 0 # Reset on successful reconnect
print(f"[{self.name}] Reconnected")
except Exception as e:
logger.warning("[%s] Reconnection failed: %s", self.name, e)
async def _read_events(self) -> None:
"""Read events from WebSocket until disconnected."""
if self._ws is None or self._ws.closed:
return
async for ws_msg in self._ws:
if ws_msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(ws_msg.data)
if data.get("type") == "event":
await self._handle_ha_event(data.get("event", {}))
except json.JSONDecodeError:
logger.debug("Invalid JSON from HA WS: %s", ws_msg.data[:200])
elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
break
async def _handle_ha_event(self, event: Dict[str, Any]) -> None:
"""Process a state_changed event from Home Assistant."""
event_data = event.get("data", {})
entity_id: str = event_data.get("entity_id", "")
if not entity_id:
return
# Apply ignore filter
if entity_id in self._ignore_entities:
return
# Apply domain/entity watch filters
domain = entity_id.split(".")[0] if "." in entity_id else ""
if self._watch_domains or self._watch_entities:
domain_match = domain in self._watch_domains if self._watch_domains else False
entity_match = entity_id in self._watch_entities if self._watch_entities else False
if not domain_match and not entity_match:
return
# Apply cooldown
now = time.time()
last = self._last_event_time.get(entity_id, 0)
if (now - last) < self._cooldown_seconds:
return
self._last_event_time[entity_id] = now
# Build human-readable message
old_state = event_data.get("old_state", {})
new_state = event_data.get("new_state", {})
message = self._format_state_change(entity_id, old_state, new_state)
if not message:
return
# Build MessageEvent and forward to handler
source = self.build_source(
chat_id="ha_events",
chat_name="Home Assistant Events",
chat_type="channel",
user_id="homeassistant",
user_name="Home Assistant",
)
msg_event = MessageEvent(
text=message,
message_type=MessageType.TEXT,
source=source,
message_id=f"ha_{entity_id}_{int(now)}",
timestamp=datetime.now(),
)
await self.handle_message(msg_event)
@staticmethod
def _format_state_change(
entity_id: str,
old_state: Dict[str, Any],
new_state: Dict[str, Any],
) -> Optional[str]:
"""Convert a state_changed event into a human-readable description."""
if not new_state:
return None
old_val = old_state.get("state", "unknown") if old_state else "unknown"
new_val = new_state.get("state", "unknown")
# Skip if state didn't actually change
if old_val == new_val:
return None
friendly_name = new_state.get("attributes", {}).get("friendly_name", entity_id)
domain = entity_id.split(".")[0] if "." in entity_id else ""
# Domain-specific formatting
if domain == "climate":
attrs = new_state.get("attributes", {})
temp = attrs.get("current_temperature", "?")
target = attrs.get("temperature", "?")
return (
f"[Home Assistant] {friendly_name}: HVAC mode changed from "
f"'{old_val}' to '{new_val}' (current: {temp}, target: {target})"
)
if domain == "sensor":
unit = new_state.get("attributes", {}).get("unit_of_measurement", "")
return (
f"[Home Assistant] {friendly_name}: changed from "
f"{old_val}{unit} to {new_val}{unit}"
)
if domain == "binary_sensor":
return (
f"[Home Assistant] {friendly_name}: "
f"{'triggered' if new_val == 'on' else 'cleared'} "
f"(was {'triggered' if old_val == 'on' else 'cleared'})"
)
if domain in ("light", "switch", "fan"):
return (
f"[Home Assistant] {friendly_name}: turned "
f"{'on' if new_val == 'on' else 'off'}"
)
if domain == "alarm_control_panel":
return (
f"[Home Assistant] {friendly_name}: alarm state changed from "
f"'{old_val}' to '{new_val}'"
)
# Generic fallback
return (
f"[Home Assistant] {friendly_name} ({entity_id}): "
f"changed from '{old_val}' to '{new_val}'"
)
# ------------------------------------------------------------------
# Outbound messaging
# ------------------------------------------------------------------
async def send(
self,
chat_id: str,
content: str,
reply_to: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send a notification via HA REST API (persistent_notification.create).
Uses the REST API instead of WebSocket to avoid a race condition
with the event listener loop that reads from the same WS connection.
"""
url = f"{self._hass_url}/api/services/persistent_notification/create"
headers = {
"Authorization": f"Bearer {self._hass_token}",
"Content-Type": "application/json",
}
payload = {
"title": "Hermes Agent",
"message": content[:self.MAX_MESSAGE_LENGTH],
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
headers=headers,
json=payload,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status < 300:
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
else:
body = await resp.text()
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
except asyncio.TimeoutError:
return SendResult(success=False, error="Timeout sending notification to HA")
except Exception as e:
return SendResult(success=False, error=str(e))
async def send_typing(self, chat_id: str) -> None:
"""No typing indicator for Home Assistant."""
pass
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
"""Return basic info about the HA event channel."""
return {
"name": "Home Assistant Events",
"type": "channel",
"url": self._hass_url,
}

View file

@ -516,6 +516,13 @@ class GatewayRunner:
return None
return SlackAdapter(config)
elif platform == Platform.HOMEASSISTANT:
from gateway.platforms.homeassistant import HomeAssistantAdapter, check_ha_requirements
if not check_ha_requirements():
logger.warning("HomeAssistant: aiohttp not installed or HASS_TOKEN not set")
return None
return HomeAssistantAdapter(config)
return None
def _is_user_authorized(self, source: SessionSource) -> bool:
@ -529,6 +536,12 @@ class GatewayRunner:
4. Global allow-all (GATEWAY_ALLOW_ALL_USERS=true)
5. Default: deny
"""
# Home Assistant events are system-generated (state changes), not
# user-initiated messages. The HASS_TOKEN already authenticates the
# connection, so HA events are always authorized.
if source.platform == Platform.HOMEASSISTANT:
return True
user_id = source.user_id
if not user_id:
return False

View file

@ -94,6 +94,7 @@ def _discover_tools():
"tools.process_registry",
"tools.send_message_tool",
"tools.honcho_tools",
"tools.homeassistant_tool",
]
import importlib
for mod_name in _modules:

View file

@ -48,6 +48,7 @@ tts-premium = ["elevenlabs"]
pty = ["ptyprocess>=0.7.0"]
honcho = ["honcho-ai>=2.0.1"]
mcp = ["mcp>=1.2.0"]
homeassistant = ["aiohttp>=3.9.0"]
all = [
"hermes-agent[modal]",
"hermes-agent[messaging]",
@ -59,6 +60,7 @@ all = [
"hermes-agent[pty]",
"hermes-agent[honcho]",
"hermes-agent[mcp]",
"hermes-agent[homeassistant]",
]
[project.scripts]

View file

@ -2212,7 +2212,7 @@ class AIAgent:
response_item_id if isinstance(response_item_id, str) else None,
)
tool_calls.append({
tc_dict = {
"id": call_id,
"call_id": call_id,
"response_item_id": response_item_id,
@ -2222,7 +2222,15 @@ class AIAgent:
"arguments": tool_call.function.arguments
},
}
)
# Preserve extra_content (e.g. Gemini thought_signature) so it
# is sent back on subsequent API calls. Without this, Gemini 3
# thinking models reject the request with a 400 error.
extra = getattr(tool_call, "extra_content", None)
if extra is not None:
if hasattr(extra, "model_dump"):
extra = extra.model_dump()
tc_dict["extra_content"] = extra
tool_calls.append(tc_dict)
msg["tool_calls"] = tool_calls
return msg

0
tests/fakes/__init__.py Normal file
View file

View file

@ -0,0 +1,288 @@
"""Fake Home Assistant server for integration testing.
Provides a real HTTP + WebSocket server (via aiohttp.web) that mimics the
Home Assistant API surface used by hermes-agent:
- ``/api/websocket`` -- WebSocket auth handshake + event push
- ``/api/states`` -- GET all entity states
- ``/api/states/{entity_id}`` -- GET single entity state
- ``/api/services/{domain}/{service}`` -- POST service call
- ``/api/services/persistent_notification/create`` -- POST notification
Usage::
async with FakeHAServer(token="test-token") as server:
url = server.url # e.g. "http://127.0.0.1:54321"
await server.push_event(event_data)
assert server.received_notifications # verify what arrived
"""
import asyncio
import json
from typing import Any, Dict, List, Optional
import aiohttp
from aiohttp import web
from aiohttp.test_utils import TestServer
# -- Sample entity data -------------------------------------------------------
ENTITY_STATES: List[Dict[str, Any]] = [
{
"entity_id": "light.bedroom",
"state": "on",
"attributes": {"friendly_name": "Bedroom Light", "brightness": 200},
"last_changed": "2025-01-15T10:30:00+00:00",
"last_updated": "2025-01-15T10:30:00+00:00",
},
{
"entity_id": "light.kitchen",
"state": "off",
"attributes": {"friendly_name": "Kitchen Light"},
"last_changed": "2025-01-15T09:00:00+00:00",
"last_updated": "2025-01-15T09:00:00+00:00",
},
{
"entity_id": "sensor.temperature",
"state": "22.5",
"attributes": {
"friendly_name": "Kitchen Temperature",
"unit_of_measurement": "C",
},
"last_changed": "2025-01-15T10:00:00+00:00",
"last_updated": "2025-01-15T10:00:00+00:00",
},
{
"entity_id": "switch.fan",
"state": "on",
"attributes": {"friendly_name": "Living Room Fan"},
"last_changed": "2025-01-15T08:00:00+00:00",
"last_updated": "2025-01-15T08:00:00+00:00",
},
{
"entity_id": "climate.thermostat",
"state": "heat",
"attributes": {
"friendly_name": "Main Thermostat",
"current_temperature": 21,
"temperature": 23,
},
"last_changed": "2025-01-15T07:00:00+00:00",
"last_updated": "2025-01-15T07:00:00+00:00",
},
]
class FakeHAServer:
"""In-process fake Home Assistant for integration tests.
Parameters
----------
token : str
The expected Bearer token for authentication.
"""
def __init__(self, token: str = "test-token-123"):
self.token = token
# Observability -- tests inspect these after exercising the adapter.
self.received_service_calls: List[Dict[str, Any]] = []
self.received_notifications: List[Dict[str, Any]] = []
# Control -- tests push events, server forwards them over WS.
self._event_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
# Flag to simulate auth rejection.
self.reject_auth = False
# Flag to simulate server errors.
self.force_500 = False
# Internal bookkeeping.
self._app: Optional[web.Application] = None
self._server: Optional[TestServer] = None
self._ws_connections: List[web.WebSocketResponse] = []
# -- Public helpers --------------------------------------------------------
@property
def url(self) -> str:
"""Base URL of the running server, e.g. ``http://127.0.0.1:12345``."""
assert self._server is not None, "Server not started"
host = self._server.host
port = self._server.port
return f"http://{host}:{port}"
async def push_event(self, event_data: Dict[str, Any]) -> None:
"""Enqueue a state_changed event for delivery over WebSocket."""
await self._event_queue.put(event_data)
# -- Lifecycle -------------------------------------------------------------
async def start(self) -> None:
self._app = self._build_app()
self._server = TestServer(self._app)
await self._server.start_server()
async def stop(self) -> None:
# Close any remaining WS connections.
for ws in self._ws_connections:
if not ws.closed:
await ws.close()
self._ws_connections.clear()
if self._server is not None:
await self._server.close()
async def __aenter__(self) -> "FakeHAServer":
await self.start()
return self
async def __aexit__(self, *exc) -> None:
await self.stop()
# -- Application construction ----------------------------------------------
def _build_app(self) -> web.Application:
app = web.Application()
app.router.add_get("/api/websocket", self._handle_ws)
app.router.add_get("/api/states", self._handle_get_states)
app.router.add_get("/api/states/{entity_id}", self._handle_get_state)
# Notification endpoint must be registered before the generic service
# route so that it takes priority.
app.router.add_post(
"/api/services/persistent_notification/create",
self._handle_notification,
)
app.router.add_post(
"/api/services/{domain}/{service}",
self._handle_call_service,
)
return app
# -- Auth helper -----------------------------------------------------------
def _check_rest_auth(self, request: web.Request) -> Optional[web.Response]:
"""Return a 401 response if the Bearer token is wrong, else None."""
auth = request.headers.get("Authorization", "")
if auth != f"Bearer {self.token}":
return web.Response(status=401, text="Unauthorized")
if self.force_500:
return web.Response(status=500, text="Internal Server Error")
return None
# -- WebSocket handler -----------------------------------------------------
async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
self._ws_connections.append(ws)
# Step 1: auth_required
await ws.send_json({"type": "auth_required", "ha_version": "2025.1.0"})
# Step 2: receive auth
msg = await ws.receive()
if msg.type != aiohttp.WSMsgType.TEXT:
await ws.close()
return ws
auth_msg = json.loads(msg.data)
# Step 3: validate
if self.reject_auth or auth_msg.get("access_token") != self.token:
await ws.send_json({"type": "auth_invalid", "message": "Invalid token"})
await ws.close()
return ws
await ws.send_json({"type": "auth_ok", "ha_version": "2025.1.0"})
# Step 4: subscribe_events
msg = await ws.receive()
if msg.type != aiohttp.WSMsgType.TEXT:
await ws.close()
return ws
sub_msg = json.loads(msg.data)
sub_id = sub_msg.get("id", 1)
# Step 5: ACK
await ws.send_json({
"id": sub_id,
"type": "result",
"success": True,
"result": None,
})
# Step 6: push events from queue until closed
try:
while not ws.closed:
try:
event_data = await asyncio.wait_for(
self._event_queue.get(), timeout=0.1,
)
await ws.send_json({
"id": sub_id,
"type": "event",
"event": event_data,
})
except asyncio.TimeoutError:
continue
except (ConnectionResetError, asyncio.CancelledError):
pass
return ws
# -- REST handlers ---------------------------------------------------------
async def _handle_get_states(self, request: web.Request) -> web.Response:
err = self._check_rest_auth(request)
if err:
return err
return web.json_response(ENTITY_STATES)
async def _handle_get_state(self, request: web.Request) -> web.Response:
err = self._check_rest_auth(request)
if err:
return err
entity_id = request.match_info["entity_id"]
for s in ENTITY_STATES:
if s["entity_id"] == entity_id:
return web.json_response(s)
return web.Response(status=404, text=f"Entity {entity_id} not found")
async def _handle_notification(self, request: web.Request) -> web.Response:
err = self._check_rest_auth(request)
if err:
return err
body = await request.json()
self.received_notifications.append(body)
return web.json_response([])
async def _handle_call_service(self, request: web.Request) -> web.Response:
err = self._check_rest_auth(request)
if err:
return err
domain = request.match_info["domain"]
service = request.match_info["service"]
body = await request.json()
self.received_service_calls.append({
"domain": domain,
"service": service,
"data": body,
})
# Return affected entities (mimics real HA behaviour for light/switch).
affected = []
entity_id = body.get("entity_id")
if entity_id:
new_state = "on" if service == "turn_on" else "off"
for s in ENTITY_STATES:
if s["entity_id"] == entity_id:
affected.append({
"entity_id": entity_id,
"state": new_state,
"attributes": s.get("attributes", {}),
})
break
return web.json_response(affected)

View file

@ -0,0 +1,604 @@
"""Tests for the Home Assistant gateway adapter.
Tests real logic: state change formatting, event filtering pipeline,
cooldown behavior, config integration, and adapter initialization.
"""
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import (
GatewayConfig,
Platform,
PlatformConfig,
)
from gateway.platforms.homeassistant import (
HomeAssistantAdapter,
check_ha_requirements,
)
# ---------------------------------------------------------------------------
# check_ha_requirements
# ---------------------------------------------------------------------------
class TestCheckRequirements:
def test_returns_false_without_token(self, monkeypatch):
monkeypatch.delenv("HASS_TOKEN", raising=False)
assert check_ha_requirements() is False
def test_returns_true_with_token(self, monkeypatch):
monkeypatch.setenv("HASS_TOKEN", "test-token")
assert check_ha_requirements() is True
@patch("gateway.platforms.homeassistant.AIOHTTP_AVAILABLE", False)
def test_returns_false_without_aiohttp(self, monkeypatch):
monkeypatch.setenv("HASS_TOKEN", "test-token")
assert check_ha_requirements() is False
# ---------------------------------------------------------------------------
# _format_state_change - pure function, all domain branches
# ---------------------------------------------------------------------------
class TestFormatStateChange:
@staticmethod
def fmt(entity_id, old_state, new_state):
return HomeAssistantAdapter._format_state_change(entity_id, old_state, new_state)
def test_climate_includes_temperatures(self):
msg = self.fmt(
"climate.thermostat",
{"state": "off"},
{"state": "heat", "attributes": {
"friendly_name": "Main Thermostat",
"current_temperature": 21.5,
"temperature": 23,
}},
)
assert "Main Thermostat" in msg
assert "'off'" in msg and "'heat'" in msg
assert "21.5" in msg and "23" in msg
def test_sensor_includes_unit(self):
msg = self.fmt(
"sensor.temperature",
{"state": "22.5"},
{"state": "25.1", "attributes": {
"friendly_name": "Living Room Temp",
"unit_of_measurement": "C",
}},
)
assert "22.5C" in msg and "25.1C" in msg
assert "Living Room Temp" in msg
def test_sensor_without_unit(self):
msg = self.fmt(
"sensor.count",
{"state": "5"},
{"state": "10", "attributes": {"friendly_name": "Counter"}},
)
assert "5" in msg and "10" in msg
def test_binary_sensor_on(self):
msg = self.fmt(
"binary_sensor.motion",
{"state": "off"},
{"state": "on", "attributes": {"friendly_name": "Hallway Motion"}},
)
assert "triggered" in msg
assert "Hallway Motion" in msg
def test_binary_sensor_off(self):
msg = self.fmt(
"binary_sensor.door",
{"state": "on"},
{"state": "off", "attributes": {"friendly_name": "Front Door"}},
)
assert "cleared" in msg
def test_light_turned_on(self):
msg = self.fmt(
"light.bedroom",
{"state": "off"},
{"state": "on", "attributes": {"friendly_name": "Bedroom Light"}},
)
assert "turned on" in msg
def test_switch_turned_off(self):
msg = self.fmt(
"switch.heater",
{"state": "on"},
{"state": "off", "attributes": {"friendly_name": "Heater"}},
)
assert "turned off" in msg
def test_fan_domain_uses_light_switch_branch(self):
msg = self.fmt(
"fan.ceiling",
{"state": "off"},
{"state": "on", "attributes": {"friendly_name": "Ceiling Fan"}},
)
assert "turned on" in msg
def test_alarm_panel(self):
msg = self.fmt(
"alarm_control_panel.home",
{"state": "disarmed"},
{"state": "armed_away", "attributes": {"friendly_name": "Home Alarm"}},
)
assert "Home Alarm" in msg
assert "armed_away" in msg and "disarmed" in msg
def test_generic_domain_includes_entity_id(self):
msg = self.fmt(
"automation.morning",
{"state": "off"},
{"state": "on", "attributes": {"friendly_name": "Morning Routine"}},
)
assert "automation.morning" in msg
assert "Morning Routine" in msg
def test_same_state_returns_none(self):
assert self.fmt(
"sensor.temp",
{"state": "22"},
{"state": "22", "attributes": {"friendly_name": "Temp"}},
) is None
def test_empty_new_state_returns_none(self):
assert self.fmt("light.x", {"state": "on"}, {}) is None
def test_no_old_state_uses_unknown(self):
msg = self.fmt(
"light.new",
None,
{"state": "on", "attributes": {"friendly_name": "New Light"}},
)
assert msg is not None
assert "New Light" in msg
def test_uses_entity_id_when_no_friendly_name(self):
msg = self.fmt(
"sensor.unnamed",
{"state": "1"},
{"state": "2", "attributes": {}},
)
assert "sensor.unnamed" in msg
# ---------------------------------------------------------------------------
# Adapter initialization from config
# ---------------------------------------------------------------------------
class TestAdapterInit:
def test_url_and_token_from_config_extra(self, monkeypatch):
monkeypatch.delenv("HASS_URL", raising=False)
monkeypatch.delenv("HASS_TOKEN", raising=False)
config = PlatformConfig(
enabled=True,
token="config-token",
extra={"url": "http://192.168.1.50:8123"},
)
adapter = HomeAssistantAdapter(config)
assert adapter._hass_token == "config-token"
assert adapter._hass_url == "http://192.168.1.50:8123"
def test_url_fallback_to_env(self, monkeypatch):
monkeypatch.setenv("HASS_URL", "http://env-host:8123")
monkeypatch.setenv("HASS_TOKEN", "env-tok")
config = PlatformConfig(enabled=True, token="env-tok")
adapter = HomeAssistantAdapter(config)
assert adapter._hass_url == "http://env-host:8123"
def test_trailing_slash_stripped(self):
config = PlatformConfig(
enabled=True, token="t",
extra={"url": "http://ha.local:8123/"},
)
adapter = HomeAssistantAdapter(config)
assert adapter._hass_url == "http://ha.local:8123"
def test_watch_filters_parsed(self):
config = PlatformConfig(
enabled=True, token="t",
extra={
"watch_domains": ["climate", "binary_sensor"],
"watch_entities": ["sensor.special"],
"ignore_entities": ["sensor.uptime", "sensor.cpu"],
"cooldown_seconds": 120,
},
)
adapter = HomeAssistantAdapter(config)
assert adapter._watch_domains == {"climate", "binary_sensor"}
assert adapter._watch_entities == {"sensor.special"}
assert adapter._ignore_entities == {"sensor.uptime", "sensor.cpu"}
assert adapter._cooldown_seconds == 120
def test_defaults_when_no_extra(self, monkeypatch):
monkeypatch.setenv("HASS_TOKEN", "tok")
config = PlatformConfig(enabled=True, token="tok")
adapter = HomeAssistantAdapter(config)
assert adapter._watch_domains == set()
assert adapter._watch_entities == set()
assert adapter._ignore_entities == set()
assert adapter._cooldown_seconds == 30
# ---------------------------------------------------------------------------
# Event filtering pipeline (_handle_ha_event)
#
# We mock handle_message (not our code, it's the base class pipeline) to
# capture the MessageEvent that _handle_ha_event produces.
# ---------------------------------------------------------------------------
def _make_adapter(**extra) -> HomeAssistantAdapter:
config = PlatformConfig(enabled=True, token="tok", extra=extra)
adapter = HomeAssistantAdapter(config)
adapter.handle_message = AsyncMock()
return adapter
def _make_event(entity_id, old_state, new_state, old_attrs=None, new_attrs=None):
return {
"data": {
"entity_id": entity_id,
"old_state": {"state": old_state, "attributes": old_attrs or {}},
"new_state": {"state": new_state, "attributes": new_attrs or {"friendly_name": entity_id}},
}
}
class TestEventFilteringPipeline:
@pytest.mark.asyncio
async def test_ignored_entity_not_forwarded(self):
adapter = _make_adapter(ignore_entities=["sensor.uptime"])
await adapter._handle_ha_event(_make_event("sensor.uptime", "100", "101"))
adapter.handle_message.assert_not_called()
@pytest.mark.asyncio
async def test_unwatched_domain_not_forwarded(self):
adapter = _make_adapter(watch_domains=["climate"])
await adapter._handle_ha_event(_make_event("light.bedroom", "off", "on"))
adapter.handle_message.assert_not_called()
@pytest.mark.asyncio
async def test_watched_domain_forwarded(self):
adapter = _make_adapter(watch_domains=["climate"], cooldown_seconds=0)
await adapter._handle_ha_event(
_make_event("climate.thermostat", "off", "heat",
new_attrs={"friendly_name": "Thermostat", "current_temperature": 20, "temperature": 22})
)
adapter.handle_message.assert_called_once()
# Verify the actual MessageEvent text content
msg_event = adapter.handle_message.call_args[0][0]
assert "Thermostat" in msg_event.text
assert "heat" in msg_event.text
assert msg_event.source.platform == Platform.HOMEASSISTANT
assert msg_event.source.chat_id == "ha_events"
@pytest.mark.asyncio
async def test_watched_entity_forwarded(self):
adapter = _make_adapter(watch_entities=["sensor.important"], cooldown_seconds=0)
await adapter._handle_ha_event(
_make_event("sensor.important", "10", "20",
new_attrs={"friendly_name": "Important Sensor", "unit_of_measurement": "W"})
)
adapter.handle_message.assert_called_once()
msg_event = adapter.handle_message.call_args[0][0]
assert "10W" in msg_event.text and "20W" in msg_event.text
@pytest.mark.asyncio
async def test_no_filters_passes_everything(self):
adapter = _make_adapter(cooldown_seconds=0)
await adapter._handle_ha_event(_make_event("cover.blinds", "closed", "open"))
adapter.handle_message.assert_called_once()
@pytest.mark.asyncio
async def test_same_state_not_forwarded(self):
adapter = _make_adapter(cooldown_seconds=0)
await adapter._handle_ha_event(_make_event("light.x", "on", "on"))
adapter.handle_message.assert_not_called()
@pytest.mark.asyncio
async def test_empty_entity_id_skipped(self):
adapter = _make_adapter()
await adapter._handle_ha_event({"data": {"entity_id": ""}})
adapter.handle_message.assert_not_called()
@pytest.mark.asyncio
async def test_message_event_has_correct_source(self):
adapter = _make_adapter(cooldown_seconds=0)
await adapter._handle_ha_event(
_make_event("light.test", "off", "on",
new_attrs={"friendly_name": "Test Light"})
)
msg_event = adapter.handle_message.call_args[0][0]
assert msg_event.source.user_name == "Home Assistant"
assert msg_event.source.chat_type == "channel"
assert msg_event.message_id.startswith("ha_light.test_")
# ---------------------------------------------------------------------------
# Cooldown behavior
# ---------------------------------------------------------------------------
class TestCooldown:
@pytest.mark.asyncio
async def test_cooldown_blocks_rapid_events(self):
adapter = _make_adapter(cooldown_seconds=60)
event = _make_event("sensor.temp", "20", "21",
new_attrs={"friendly_name": "Temp"})
await adapter._handle_ha_event(event)
assert adapter.handle_message.call_count == 1
# Second event immediately after should be blocked
event2 = _make_event("sensor.temp", "21", "22",
new_attrs={"friendly_name": "Temp"})
await adapter._handle_ha_event(event2)
assert adapter.handle_message.call_count == 1 # Still 1
@pytest.mark.asyncio
async def test_cooldown_expires(self):
adapter = _make_adapter(cooldown_seconds=1)
event = _make_event("sensor.temp", "20", "21",
new_attrs={"friendly_name": "Temp"})
await adapter._handle_ha_event(event)
assert adapter.handle_message.call_count == 1
# Simulate time passing beyond cooldown
adapter._last_event_time["sensor.temp"] = time.time() - 2
event2 = _make_event("sensor.temp", "21", "22",
new_attrs={"friendly_name": "Temp"})
await adapter._handle_ha_event(event2)
assert adapter.handle_message.call_count == 2
@pytest.mark.asyncio
async def test_different_entities_independent_cooldowns(self):
adapter = _make_adapter(cooldown_seconds=60)
await adapter._handle_ha_event(
_make_event("sensor.a", "1", "2", new_attrs={"friendly_name": "A"})
)
await adapter._handle_ha_event(
_make_event("sensor.b", "3", "4", new_attrs={"friendly_name": "B"})
)
# Both should pass - different entities
assert adapter.handle_message.call_count == 2
# Same entity again - should be blocked
await adapter._handle_ha_event(
_make_event("sensor.a", "2", "3", new_attrs={"friendly_name": "A"})
)
assert adapter.handle_message.call_count == 2 # Still 2
@pytest.mark.asyncio
async def test_zero_cooldown_passes_all(self):
adapter = _make_adapter(cooldown_seconds=0)
for i in range(5):
await adapter._handle_ha_event(
_make_event("sensor.temp", str(i), str(i + 1),
new_attrs={"friendly_name": "Temp"})
)
assert adapter.handle_message.call_count == 5
# ---------------------------------------------------------------------------
# Config integration (env overrides, round-trip)
# ---------------------------------------------------------------------------
class TestConfigIntegration:
def test_env_override_creates_ha_platform(self, monkeypatch):
monkeypatch.setenv("HASS_TOKEN", "env-token")
monkeypatch.setenv("HASS_URL", "http://10.0.0.5:8123")
# Clear other platform tokens
for v in ["TELEGRAM_BOT_TOKEN", "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]:
monkeypatch.delenv(v, raising=False)
from gateway.config import load_gateway_config
config = load_gateway_config()
assert Platform.HOMEASSISTANT in config.platforms
ha = config.platforms[Platform.HOMEASSISTANT]
assert ha.enabled is True
assert ha.token == "env-token"
assert ha.extra["url"] == "http://10.0.0.5:8123"
def test_no_env_no_platform(self, monkeypatch):
for v in ["HASS_TOKEN", "HASS_URL", "TELEGRAM_BOT_TOKEN",
"DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]:
monkeypatch.delenv(v, raising=False)
from gateway.config import load_gateway_config
config = load_gateway_config()
assert Platform.HOMEASSISTANT not in config.platforms
def test_config_roundtrip_preserves_extra(self):
config = GatewayConfig(
platforms={
Platform.HOMEASSISTANT: PlatformConfig(
enabled=True,
token="tok",
extra={
"url": "http://ha:8123",
"watch_domains": ["climate"],
"cooldown_seconds": 45,
},
),
},
)
d = config.to_dict()
restored = GatewayConfig.from_dict(d)
ha = restored.platforms[Platform.HOMEASSISTANT]
assert ha.enabled is True
assert ha.token == "tok"
assert ha.extra["watch_domains"] == ["climate"]
assert ha.extra["cooldown_seconds"] == 45
def test_connected_platforms_includes_ha(self):
config = GatewayConfig(
platforms={
Platform.HOMEASSISTANT: PlatformConfig(enabled=True, token="tok"),
Platform.TELEGRAM: PlatformConfig(enabled=False, token="t"),
},
)
connected = config.get_connected_platforms()
assert Platform.HOMEASSISTANT in connected
assert Platform.TELEGRAM not in connected
# ---------------------------------------------------------------------------
# send() via REST API
# ---------------------------------------------------------------------------
class TestSendViaRestApi:
"""send() uses REST API (not WebSocket) to avoid race conditions."""
@staticmethod
def _mock_aiohttp_session(response_status=200, response_text="OK"):
"""Build a mock aiohttp session + response for async-with patterns.
aiohttp.ClientSession() is a sync constructor whose return value
is used as ``async with session:``. ``session.post(...)`` returns a
context-manager (not a coroutine), so both layers use MagicMock for
the call and AsyncMock only for ``__aenter__`` / ``__aexit__``.
"""
mock_response = MagicMock()
mock_response.status = response_status
mock_response.text = AsyncMock(return_value=response_text)
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=False)
mock_session = MagicMock()
mock_session.post = MagicMock(return_value=mock_response)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
return mock_session
@pytest.mark.asyncio
async def test_send_success(self):
adapter = _make_adapter()
mock_session = self._mock_aiohttp_session(200)
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
mock_aiohttp.ClientTimeout = lambda total: total
result = await adapter.send("ha_events", "Test notification")
assert result.success is True
# Verify the REST API was called with correct payload
call_args = mock_session.post.call_args
assert "/api/services/persistent_notification/create" in call_args[0][0]
assert call_args[1]["json"]["title"] == "Hermes Agent"
assert call_args[1]["json"]["message"] == "Test notification"
assert "Bearer tok" in call_args[1]["headers"]["Authorization"]
@pytest.mark.asyncio
async def test_send_http_error(self):
adapter = _make_adapter()
mock_session = self._mock_aiohttp_session(401, "Unauthorized")
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
mock_aiohttp.ClientTimeout = lambda total: total
result = await adapter.send("ha_events", "Test")
assert result.success is False
assert "401" in result.error
@pytest.mark.asyncio
async def test_send_truncates_long_message(self):
adapter = _make_adapter()
mock_session = self._mock_aiohttp_session(200)
long_message = "x" * 10000
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
mock_aiohttp.ClientTimeout = lambda total: total
await adapter.send("ha_events", long_message)
sent_message = mock_session.post.call_args[1]["json"]["message"]
assert len(sent_message) == 4096
@pytest.mark.asyncio
async def test_send_does_not_use_websocket(self):
"""send() must use REST API, not the WS connection (race condition fix)."""
adapter = _make_adapter()
adapter._ws = AsyncMock() # Simulate an active WS
mock_session = self._mock_aiohttp_session(200)
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
mock_aiohttp.ClientTimeout = lambda total: total
await adapter.send("ha_events", "Test")
# WS should NOT have been used for sending
adapter._ws.send_json.assert_not_called()
adapter._ws.receive_json.assert_not_called()
# ---------------------------------------------------------------------------
# Toolset integration
# ---------------------------------------------------------------------------
class TestToolsetIntegration:
def test_homeassistant_toolset_resolves(self):
from toolsets import resolve_toolset
tools = resolve_toolset("homeassistant")
assert set(tools) == {"ha_list_entities", "ha_get_state", "ha_call_service"}
def test_gateway_toolset_includes_ha_tools(self):
from toolsets import resolve_toolset
gateway_tools = resolve_toolset("hermes-gateway")
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"):
assert tool in gateway_tools
def test_hermes_core_tools_includes_ha(self):
from toolsets import _HERMES_CORE_TOOLS
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"):
assert tool in _HERMES_CORE_TOOLS
# ---------------------------------------------------------------------------
# WebSocket URL construction
# ---------------------------------------------------------------------------
class TestWsUrlConstruction:
def test_http_to_ws(self):
config = PlatformConfig(enabled=True, token="t", extra={"url": "http://ha:8123"})
adapter = HomeAssistantAdapter(config)
ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://")
assert ws_url == "ws://ha:8123"
def test_https_to_wss(self):
config = PlatformConfig(enabled=True, token="t", extra={"url": "https://ha.example.com"})
adapter = HomeAssistantAdapter(config)
ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://")
assert ws_url == "wss://ha.example.com"

View file

@ -0,0 +1,341 @@
"""Integration tests for Home Assistant (tool + gateway).
Spins up a real in-process fake HA server (HTTP + WebSocket) and exercises
the full adapter and tool handler paths over real TCP connections.
No mocks -- only real async I/O against a fake server.
Run with: uv run pytest tests/integration/test_ha_integration.py -v
"""
import asyncio
import pytest
pytestmark = pytest.mark.integration
from unittest.mock import AsyncMock
from gateway.config import Platform, PlatformConfig
from gateway.platforms.homeassistant import HomeAssistantAdapter
from tests.fakes.fake_ha_server import FakeHAServer, ENTITY_STATES
from tools.homeassistant_tool import (
_async_call_service,
_async_get_state,
_async_list_entities,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _adapter_for(server: FakeHAServer, **extra) -> HomeAssistantAdapter:
"""Create an adapter pointed at the fake server."""
config = PlatformConfig(
enabled=True,
token=server.token,
extra={"url": server.url, **extra},
)
return HomeAssistantAdapter(config)
# ---------------------------------------------------------------------------
# 1. Gateway -- WebSocket lifecycle
# ---------------------------------------------------------------------------
class TestGatewayWebSocket:
@pytest.mark.asyncio
async def test_connect_auth_subscribe(self):
"""Full WS handshake succeeds: auth_required -> auth -> auth_ok -> subscribe -> ACK."""
async with FakeHAServer() as server:
adapter = _adapter_for(server)
connected = await adapter.connect()
assert connected is True
assert adapter._running is True
assert adapter._ws is not None
assert not adapter._ws.closed
await adapter.disconnect()
@pytest.mark.asyncio
async def test_connect_auth_rejected(self):
"""connect() returns False when the server rejects auth."""
async with FakeHAServer() as server:
server.reject_auth = True
adapter = _adapter_for(server)
connected = await adapter.connect()
assert connected is False
@pytest.mark.asyncio
async def test_event_received_and_forwarded(self):
"""Server pushes event -> adapter calls handle_message with correct MessageEvent."""
async with FakeHAServer() as server:
adapter = _adapter_for(server)
adapter.handle_message = AsyncMock()
await adapter.connect()
# Push a state_changed event
await server.push_event({
"data": {
"entity_id": "light.bedroom",
"old_state": {"state": "off", "attributes": {}},
"new_state": {
"state": "on",
"attributes": {"friendly_name": "Bedroom Light"},
},
}
})
# Wait for the adapter to process it
for _ in range(50):
if adapter.handle_message.call_count > 0:
break
await asyncio.sleep(0.05)
assert adapter.handle_message.call_count == 1
msg_event = adapter.handle_message.call_args[0][0]
assert "Bedroom Light" in msg_event.text
assert "turned on" in msg_event.text
assert msg_event.source.platform == Platform.HOMEASSISTANT
await adapter.disconnect()
@pytest.mark.asyncio
async def test_event_filtering_ignores_unwatched(self):
"""Events outside watch_domains are silently dropped."""
async with FakeHAServer() as server:
adapter = _adapter_for(server, watch_domains=["climate"])
adapter.handle_message = AsyncMock()
await adapter.connect()
# Push a light event (not in watch_domains)
await server.push_event({
"data": {
"entity_id": "light.bedroom",
"old_state": {"state": "off", "attributes": {}},
"new_state": {
"state": "on",
"attributes": {"friendly_name": "Bedroom Light"},
},
}
})
await asyncio.sleep(0.5)
assert adapter.handle_message.call_count == 0
await adapter.disconnect()
@pytest.mark.asyncio
async def test_disconnect_closes_cleanly(self):
"""disconnect() cancels listener and closes WebSocket."""
async with FakeHAServer() as server:
adapter = _adapter_for(server)
await adapter.connect()
ws_ref = adapter._ws
await adapter.disconnect()
assert adapter._running is False
assert adapter._listen_task is None
assert adapter._ws is None
# The original WS reference should be closed
assert ws_ref.closed
# ---------------------------------------------------------------------------
# 2. REST tool handlers (real HTTP against fake server)
# ---------------------------------------------------------------------------
class TestToolRest:
"""Call the async tool functions directly against the fake server.
Note: we call ``_async_*`` instead of the sync ``_handle_*`` wrappers
because the sync wrappers use ``_run_async`` which blocks the event
loop, deadlocking with the in-process fake server. The async functions
are the real logic; the sync wrappers are trivial bridge code already
covered by unit tests.
"""
@pytest.mark.asyncio
async def test_list_entities_returns_all(self, monkeypatch):
"""_async_list_entities returns all entities from the fake server."""
async with FakeHAServer() as server:
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_URL", server.url,
)
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_TOKEN", server.token,
)
result = await _async_list_entities()
assert result["count"] == len(ENTITY_STATES)
ids = {e["entity_id"] for e in result["entities"]}
assert "light.bedroom" in ids
assert "climate.thermostat" in ids
@pytest.mark.asyncio
async def test_list_entities_domain_filter(self, monkeypatch):
"""Domain filter is applied after fetching from server."""
async with FakeHAServer() as server:
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_URL", server.url,
)
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_TOKEN", server.token,
)
result = await _async_list_entities(domain="light")
assert result["count"] == 2
for e in result["entities"]:
assert e["entity_id"].startswith("light.")
@pytest.mark.asyncio
async def test_get_state_single_entity(self, monkeypatch):
"""_async_get_state returns full entity details."""
async with FakeHAServer() as server:
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_URL", server.url,
)
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_TOKEN", server.token,
)
result = await _async_get_state("light.bedroom")
assert result["entity_id"] == "light.bedroom"
assert result["state"] == "on"
assert result["attributes"]["brightness"] == 200
assert result["last_changed"] is not None
@pytest.mark.asyncio
async def test_get_state_not_found(self, monkeypatch):
"""Non-existent entity raises an aiohttp error (404)."""
import aiohttp as _aiohttp
async with FakeHAServer() as server:
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_URL", server.url,
)
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_TOKEN", server.token,
)
with pytest.raises(_aiohttp.ClientResponseError) as exc_info:
await _async_get_state("light.nonexistent")
assert exc_info.value.status == 404
@pytest.mark.asyncio
async def test_call_service_turn_on(self, monkeypatch):
"""_async_call_service sends correct payload and server records it."""
async with FakeHAServer() as server:
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_URL", server.url,
)
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_TOKEN", server.token,
)
result = await _async_call_service(
domain="light",
service="turn_on",
entity_id="light.bedroom",
data={"brightness": 255},
)
assert result["success"] is True
assert result["service"] == "light.turn_on"
assert len(result["affected_entities"]) == 1
assert result["affected_entities"][0]["state"] == "on"
# Verify fake server recorded the call
assert len(server.received_service_calls) == 1
call = server.received_service_calls[0]
assert call["domain"] == "light"
assert call["service"] == "turn_on"
assert call["data"]["entity_id"] == "light.bedroom"
assert call["data"]["brightness"] == 255
# ---------------------------------------------------------------------------
# 3. send() -- REST notification
# ---------------------------------------------------------------------------
class TestSendNotification:
@pytest.mark.asyncio
async def test_send_notification_delivered(self):
"""Adapter send() delivers notification to fake server REST endpoint."""
async with FakeHAServer() as server:
adapter = _adapter_for(server)
result = await adapter.send("ha_events", "Test notification from agent")
assert result.success is True
assert len(server.received_notifications) == 1
notif = server.received_notifications[0]
assert notif["title"] == "Hermes Agent"
assert notif["message"] == "Test notification from agent"
@pytest.mark.asyncio
async def test_send_auth_failure(self):
"""send() returns failure when token is wrong."""
async with FakeHAServer() as server:
config = PlatformConfig(
enabled=True,
token="wrong-token",
extra={"url": server.url},
)
adapter = HomeAssistantAdapter(config)
result = await adapter.send("ha_events", "Should fail")
assert result.success is False
assert "401" in result.error
# ---------------------------------------------------------------------------
# 4. Auth and error cases
# ---------------------------------------------------------------------------
class TestAuthAndErrors:
@pytest.mark.asyncio
async def test_rest_unauthorized(self, monkeypatch):
"""Async function raises on 401 when token is wrong."""
import aiohttp as _aiohttp
async with FakeHAServer() as server:
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_URL", server.url,
)
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_TOKEN", "bad-token",
)
with pytest.raises(_aiohttp.ClientResponseError) as exc_info:
await _async_list_entities()
assert exc_info.value.status == 401
@pytest.mark.asyncio
async def test_rest_server_error(self, monkeypatch):
"""Async function raises on 500 response."""
import aiohttp as _aiohttp
async with FakeHAServer() as server:
server.force_500 = True
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_URL", server.url,
)
monkeypatch.setattr(
"tools.homeassistant_tool._HASS_TOKEN", server.token,
)
with pytest.raises(_aiohttp.ClientResponseError) as exc_info:
await _async_list_entities()
assert exc_info.value.status == 500

View file

@ -546,6 +546,24 @@ class TestBuildAssistantMessage:
result = agent._build_assistant_message(msg, "stop")
assert result["content"] == ""
def test_tool_call_extra_content_preserved(self, agent):
"""Gemini thinking models attach extra_content with thought_signature
to tool calls. This must be preserved so subsequent API calls include it."""
tc = _mock_tool_call(name="get_weather", arguments='{"city":"NYC"}', call_id="c2")
tc.extra_content = {"google": {"thought_signature": "abc123"}}
msg = _mock_assistant_msg(content="", tool_calls=[tc])
result = agent._build_assistant_message(msg, "tool_calls")
assert result["tool_calls"][0]["extra_content"] == {
"google": {"thought_signature": "abc123"}
}
def test_tool_call_without_extra_content(self, agent):
"""Standard tool calls (no thinking model) should not have extra_content."""
tc = _mock_tool_call(name="web_search", arguments='{}', call_id="c3")
msg = _mock_assistant_msg(content="", tool_calls=[tc])
result = agent._build_assistant_message(msg, "tool_calls")
assert "extra_content" not in result["tool_calls"][0]
class TestFormatToolsForSystemMessage:
def test_no_tools_returns_empty_array(self, agent):

View file

@ -0,0 +1,373 @@
"""Tests for the Home Assistant tool module.
Tests real logic: entity filtering, payload building, response parsing,
handler validation, and availability gating.
"""
import json
import pytest
from tools.homeassistant_tool import (
_check_ha_available,
_filter_and_summarize,
_build_service_payload,
_parse_service_response,
_get_headers,
_handle_get_state,
_handle_call_service,
_BLOCKED_DOMAINS,
_ENTITY_ID_RE,
)
# ---------------------------------------------------------------------------
# Sample HA state data (matches real HA /api/states response shape)
# ---------------------------------------------------------------------------
SAMPLE_STATES = [
{"entity_id": "light.bedroom", "state": "on", "attributes": {"friendly_name": "Bedroom Light", "brightness": 200}},
{"entity_id": "light.kitchen", "state": "off", "attributes": {"friendly_name": "Kitchen Light"}},
{"entity_id": "switch.fan", "state": "on", "attributes": {"friendly_name": "Living Room Fan"}},
{"entity_id": "sensor.temperature", "state": "22.5", "attributes": {"friendly_name": "Kitchen Temperature", "unit_of_measurement": "C"}},
{"entity_id": "climate.thermostat", "state": "heat", "attributes": {"friendly_name": "Main Thermostat", "current_temperature": 21}},
{"entity_id": "binary_sensor.motion", "state": "off", "attributes": {"friendly_name": "Hallway Motion"}},
{"entity_id": "sensor.humidity", "state": "55", "attributes": {"friendly_name": "Bedroom Humidity", "area": "bedroom"}},
]
# ---------------------------------------------------------------------------
# Entity filtering and summarization
# ---------------------------------------------------------------------------
class TestFilterAndSummarize:
def test_no_filters_returns_all(self):
result = _filter_and_summarize(SAMPLE_STATES)
assert result["count"] == 7
ids = {e["entity_id"] for e in result["entities"]}
assert "light.bedroom" in ids
assert "climate.thermostat" in ids
def test_domain_filter_lights(self):
result = _filter_and_summarize(SAMPLE_STATES, domain="light")
assert result["count"] == 2
for e in result["entities"]:
assert e["entity_id"].startswith("light.")
def test_domain_filter_sensor(self):
result = _filter_and_summarize(SAMPLE_STATES, domain="sensor")
assert result["count"] == 2
ids = {e["entity_id"] for e in result["entities"]}
assert ids == {"sensor.temperature", "sensor.humidity"}
def test_domain_filter_no_matches(self):
result = _filter_and_summarize(SAMPLE_STATES, domain="media_player")
assert result["count"] == 0
assert result["entities"] == []
def test_area_filter_by_friendly_name(self):
result = _filter_and_summarize(SAMPLE_STATES, area="kitchen")
assert result["count"] == 2
ids = {e["entity_id"] for e in result["entities"]}
assert "light.kitchen" in ids
assert "sensor.temperature" in ids
def test_area_filter_by_area_attribute(self):
result = _filter_and_summarize(SAMPLE_STATES, area="bedroom")
ids = {e["entity_id"] for e in result["entities"]}
# "Bedroom Light" matches via friendly_name, "Bedroom Humidity" matches via area attr
assert "light.bedroom" in ids
assert "sensor.humidity" in ids
def test_area_filter_case_insensitive(self):
result = _filter_and_summarize(SAMPLE_STATES, area="KITCHEN")
assert result["count"] == 2
def test_combined_domain_and_area(self):
result = _filter_and_summarize(SAMPLE_STATES, domain="sensor", area="kitchen")
assert result["count"] == 1
assert result["entities"][0]["entity_id"] == "sensor.temperature"
def test_summary_includes_friendly_name(self):
result = _filter_and_summarize(SAMPLE_STATES, domain="climate")
assert result["entities"][0]["friendly_name"] == "Main Thermostat"
assert result["entities"][0]["state"] == "heat"
def test_empty_states_list(self):
result = _filter_and_summarize([])
assert result["count"] == 0
def test_missing_attributes_handled(self):
states = [{"entity_id": "light.x", "state": "on"}]
result = _filter_and_summarize(states)
assert result["count"] == 1
assert result["entities"][0]["friendly_name"] == ""
# ---------------------------------------------------------------------------
# Service payload building
# ---------------------------------------------------------------------------
class TestBuildServicePayload:
def test_entity_id_only(self):
payload = _build_service_payload(entity_id="light.bedroom")
assert payload == {"entity_id": "light.bedroom"}
def test_data_only(self):
payload = _build_service_payload(data={"brightness": 255})
assert payload == {"brightness": 255}
def test_entity_id_and_data(self):
payload = _build_service_payload(
entity_id="light.bedroom",
data={"brightness": 200, "color_name": "blue"},
)
assert payload["entity_id"] == "light.bedroom"
assert payload["brightness"] == 200
assert payload["color_name"] == "blue"
def test_no_args_returns_empty(self):
payload = _build_service_payload()
assert payload == {}
def test_entity_id_param_takes_precedence_over_data(self):
payload = _build_service_payload(
entity_id="light.a",
data={"entity_id": "light.b"},
)
# explicit entity_id parameter wins over data["entity_id"]
assert payload["entity_id"] == "light.a"
# ---------------------------------------------------------------------------
# Service response parsing
# ---------------------------------------------------------------------------
class TestParseServiceResponse:
def test_list_response_extracts_entities(self):
ha_response = [
{"entity_id": "light.bedroom", "state": "on", "attributes": {}},
{"entity_id": "light.kitchen", "state": "on", "attributes": {}},
]
result = _parse_service_response("light", "turn_on", ha_response)
assert result["success"] is True
assert result["service"] == "light.turn_on"
assert len(result["affected_entities"]) == 2
assert result["affected_entities"][0]["entity_id"] == "light.bedroom"
def test_empty_list_response(self):
result = _parse_service_response("scene", "turn_on", [])
assert result["success"] is True
assert result["affected_entities"] == []
def test_non_list_response(self):
# Some HA services return a dict instead of a list
result = _parse_service_response("script", "run", {"result": "ok"})
assert result["success"] is True
assert result["affected_entities"] == []
def test_none_response(self):
result = _parse_service_response("automation", "trigger", None)
assert result["success"] is True
assert result["affected_entities"] == []
def test_service_name_format(self):
result = _parse_service_response("climate", "set_temperature", [])
assert result["service"] == "climate.set_temperature"
# ---------------------------------------------------------------------------
# Handler validation (no mocks - these paths don't reach the network)
# ---------------------------------------------------------------------------
class TestHandlerValidation:
def test_get_state_missing_entity_id(self):
result = json.loads(_handle_get_state({}))
assert "error" in result
assert "entity_id" in result["error"]
def test_get_state_empty_entity_id(self):
result = json.loads(_handle_get_state({"entity_id": ""}))
assert "error" in result
def test_call_service_missing_domain(self):
result = json.loads(_handle_call_service({"service": "turn_on"}))
assert "error" in result
assert "domain" in result["error"]
def test_call_service_missing_service(self):
result = json.loads(_handle_call_service({"domain": "light"}))
assert "error" in result
assert "service" in result["error"]
def test_call_service_missing_both(self):
result = json.loads(_handle_call_service({}))
assert "error" in result
def test_call_service_empty_strings(self):
result = json.loads(_handle_call_service({"domain": "", "service": ""}))
assert "error" in result
# ---------------------------------------------------------------------------
# Security: domain blocklist
# ---------------------------------------------------------------------------
class TestDomainBlocklist:
"""Verify dangerous HA service domains are blocked."""
@pytest.mark.parametrize("domain", sorted(_BLOCKED_DOMAINS))
def test_blocked_domain_rejected(self, domain):
result = json.loads(_handle_call_service({
"domain": domain, "service": "any_service"
}))
assert "error" in result
assert "blocked" in result["error"].lower()
def test_safe_domain_not_blocked(self):
"""Safe domains like 'light' should not be blocked (will fail on network, not blocklist)."""
# This will try to make a real HTTP call and fail, but the important thing
# is it does NOT return a "blocked" error
result = json.loads(_handle_call_service({
"domain": "light", "service": "turn_on", "entity_id": "light.test"
}))
# Should fail with a network/connection error, not a "blocked" error
if "error" in result:
assert "blocked" not in result["error"].lower()
def test_blocked_domains_include_shell_command(self):
assert "shell_command" in _BLOCKED_DOMAINS
def test_blocked_domains_include_hassio(self):
assert "hassio" in _BLOCKED_DOMAINS
def test_blocked_domains_include_rest_command(self):
assert "rest_command" in _BLOCKED_DOMAINS
# ---------------------------------------------------------------------------
# Security: entity_id validation
# ---------------------------------------------------------------------------
class TestEntityIdValidation:
"""Verify entity_id format validation prevents path traversal."""
def test_valid_entity_id_accepted(self):
assert _ENTITY_ID_RE.match("light.bedroom")
assert _ENTITY_ID_RE.match("sensor.temperature_1")
assert _ENTITY_ID_RE.match("binary_sensor.motion")
assert _ENTITY_ID_RE.match("climate.main_thermostat")
def test_path_traversal_rejected(self):
assert _ENTITY_ID_RE.match("../../config") is None
assert _ENTITY_ID_RE.match("light/../../../etc/passwd") is None
assert _ENTITY_ID_RE.match("../api/config") is None
def test_special_chars_rejected(self):
assert _ENTITY_ID_RE.match("light.bed room") is None # space
assert _ENTITY_ID_RE.match("light.bed;rm -rf") is None # semicolon
assert _ENTITY_ID_RE.match("light.bed/room") is None # slash
assert _ENTITY_ID_RE.match("LIGHT.BEDROOM") is None # uppercase
def test_missing_domain_rejected(self):
assert _ENTITY_ID_RE.match(".bedroom") is None
assert _ENTITY_ID_RE.match("bedroom") is None
def test_get_state_rejects_invalid_entity_id(self):
result = json.loads(_handle_get_state({"entity_id": "../../config"}))
assert "error" in result
assert "Invalid entity_id" in result["error"]
def test_call_service_rejects_invalid_entity_id(self):
result = json.loads(_handle_call_service({
"domain": "light",
"service": "turn_on",
"entity_id": "../../../etc/passwd",
}))
assert "error" in result
assert "Invalid entity_id" in result["error"]
def test_call_service_allows_no_entity_id(self):
"""Some services (like scene.turn_on) don't need entity_id."""
# Will fail on network, but should NOT fail on entity_id validation
result = json.loads(_handle_call_service({
"domain": "scene", "service": "turn_on"
}))
if "error" in result:
assert "Invalid entity_id" not in result["error"]
# ---------------------------------------------------------------------------
# Availability check
# ---------------------------------------------------------------------------
class TestCheckAvailable:
def test_unavailable_without_token(self, monkeypatch):
monkeypatch.delenv("HASS_TOKEN", raising=False)
assert _check_ha_available() is False
def test_available_with_token(self, monkeypatch):
monkeypatch.setenv("HASS_TOKEN", "eyJ0eXAiOiJKV1Q")
assert _check_ha_available() is True
def test_empty_token_is_unavailable(self, monkeypatch):
monkeypatch.setenv("HASS_TOKEN", "")
assert _check_ha_available() is False
# ---------------------------------------------------------------------------
# Auth headers
# ---------------------------------------------------------------------------
class TestGetHeaders:
def test_bearer_token_format(self, monkeypatch):
monkeypatch.setattr("tools.homeassistant_tool._HASS_TOKEN", "my-secret-token")
headers = _get_headers()
assert headers["Authorization"] == "Bearer my-secret-token"
assert headers["Content-Type"] == "application/json"
# ---------------------------------------------------------------------------
# Registry integration
# ---------------------------------------------------------------------------
class TestRegistration:
def test_tools_registered_in_registry(self):
from tools.registry import registry
names = registry.get_all_tool_names()
assert "ha_list_entities" in names
assert "ha_get_state" in names
assert "ha_call_service" in names
def test_tools_in_homeassistant_toolset(self):
from tools.registry import registry
toolset_map = registry.get_tool_to_toolset_map()
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"):
assert toolset_map[tool] == "homeassistant"
def test_check_fn_gates_availability(self, monkeypatch):
"""Registry should exclude HA tools when HASS_TOKEN is not set."""
from tools.registry import registry
monkeypatch.delenv("HASS_TOKEN", raising=False)
defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"})
assert len(defs) == 0
def test_check_fn_includes_when_token_set(self, monkeypatch):
"""Registry should include HA tools when HASS_TOKEN is set."""
from tools.registry import registry
monkeypatch.setenv("HASS_TOKEN", "test-token")
defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"})
assert len(defs) == 3

392
tools/homeassistant_tool.py Normal file
View file

@ -0,0 +1,392 @@
"""Home Assistant tool for controlling smart home devices via REST API.
Registers three LLM-callable tools:
- ``ha_list_entities`` -- list/filter entities by domain or area
- ``ha_get_state`` -- get detailed state of a single entity
- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.)
Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var.
The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123).
"""
import asyncio
import json
import logging
import os
import re
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/")
_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "")
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
# Service domains blocked for security -- these allow arbitrary code/command
# execution on the HA host or enable SSRF attacks on the local network.
# HA provides zero service-level access control; all safety must be in our layer.
_BLOCKED_DOMAINS = frozenset({
"shell_command", # arbitrary shell commands as root in HA container
"command_line", # sensors/switches that execute shell commands
"python_script", # sandboxed but can escalate via hass.services.call()
"pyscript", # scripting integration with broader access
"hassio", # addon control, host shutdown/reboot, stdin to containers
"rest_command", # HTTP requests from HA server (SSRF vector)
})
def _get_headers() -> Dict[str, str]:
"""Return authorization headers for HA REST API."""
return {
"Authorization": f"Bearer {_HASS_TOKEN}",
"Content-Type": "application/json",
}
# ---------------------------------------------------------------------------
# Async helpers (called from sync handlers via run_until_complete)
# ---------------------------------------------------------------------------
def _filter_and_summarize(
states: list,
domain: Optional[str] = None,
area: Optional[str] = None,
) -> Dict[str, Any]:
"""Filter raw HA states by domain/area and return a compact summary."""
if domain:
states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
if area:
area_lower = area.lower()
states = [
s for s in states
if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
]
entities = []
for s in states:
entities.append({
"entity_id": s["entity_id"],
"state": s["state"],
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
})
return {"count": len(entities), "entities": entities}
async def _async_list_entities(
domain: Optional[str] = None,
area: Optional[str] = None,
) -> Dict[str, Any]:
"""Fetch entity states from HA and optionally filter by domain/area."""
import aiohttp
url = f"{_HASS_URL}/api/states"
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=15)) as resp:
resp.raise_for_status()
states = await resp.json()
return _filter_and_summarize(states, domain, area)
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
"""Fetch detailed state of a single entity."""
import aiohttp
url = f"{_HASS_URL}/api/states/{entity_id}"
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=10)) as resp:
resp.raise_for_status()
data = await resp.json()
return {
"entity_id": data["entity_id"],
"state": data["state"],
"attributes": data.get("attributes", {}),
"last_changed": data.get("last_changed"),
"last_updated": data.get("last_updated"),
}
def _build_service_payload(
entity_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Build the JSON payload for a HA service call."""
payload: Dict[str, Any] = {}
if data:
payload.update(data)
# entity_id parameter takes precedence over data["entity_id"]
if entity_id:
payload["entity_id"] = entity_id
return payload
def _parse_service_response(
domain: str,
service: str,
result: Any,
) -> Dict[str, Any]:
"""Parse HA service call response into a structured result."""
affected = []
if isinstance(result, list):
for s in result:
affected.append({
"entity_id": s.get("entity_id", ""),
"state": s.get("state", ""),
})
return {
"success": True,
"service": f"{domain}.{service}",
"affected_entities": affected,
}
async def _async_call_service(
domain: str,
service: str,
entity_id: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Call a Home Assistant service."""
import aiohttp
url = f"{_HASS_URL}/api/services/{domain}/{service}"
payload = _build_service_payload(entity_id, data)
async with aiohttp.ClientSession() as session:
async with session.post(
url,
headers=_get_headers(),
json=payload,
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
resp.raise_for_status()
result = await resp.json()
return _parse_service_response(domain, service, result)
# ---------------------------------------------------------------------------
# Sync wrappers (handler signature: (args, **kw) -> str)
# ---------------------------------------------------------------------------
def _run_async(coro):
"""Run an async coroutine from a sync handler."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
# Already inside an event loop -- create a new thread
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
future = pool.submit(asyncio.run, coro)
return future.result(timeout=30)
else:
return asyncio.run(coro)
def _handle_list_entities(args: dict, **kw) -> str:
"""Handler for ha_list_entities tool."""
domain = args.get("domain")
area = args.get("area")
try:
result = _run_async(_async_list_entities(domain=domain, area=area))
return json.dumps({"result": result})
except Exception as e:
logger.error("ha_list_entities error: %s", e)
return json.dumps({"error": f"Failed to list entities: {e}"})
def _handle_get_state(args: dict, **kw) -> str:
"""Handler for ha_get_state tool."""
entity_id = args.get("entity_id", "")
if not entity_id:
return json.dumps({"error": "Missing required parameter: entity_id"})
if not _ENTITY_ID_RE.match(entity_id):
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
try:
result = _run_async(_async_get_state(entity_id))
return json.dumps({"result": result})
except Exception as e:
logger.error("ha_get_state error: %s", e)
return json.dumps({"error": f"Failed to get state for {entity_id}: {e}"})
def _handle_call_service(args: dict, **kw) -> str:
"""Handler for ha_call_service tool."""
domain = args.get("domain", "")
service = args.get("service", "")
if not domain or not service:
return json.dumps({"error": "Missing required parameters: domain and service"})
if domain in _BLOCKED_DOMAINS:
return json.dumps({
"error": f"Service domain '{domain}' is blocked for security. "
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
})
entity_id = args.get("entity_id")
if entity_id and not _ENTITY_ID_RE.match(entity_id):
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
data = args.get("data")
try:
result = _run_async(_async_call_service(domain, service, entity_id, data))
return json.dumps({"result": result})
except Exception as e:
logger.error("ha_call_service error: %s", e)
return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"})
# ---------------------------------------------------------------------------
# Availability check
# ---------------------------------------------------------------------------
def _check_ha_available() -> bool:
"""Tool is only available when HASS_TOKEN is set."""
return bool(os.getenv("HASS_TOKEN"))
# ---------------------------------------------------------------------------
# Tool schemas
# ---------------------------------------------------------------------------
HA_LIST_ENTITIES_SCHEMA = {
"name": "ha_list_entities",
"description": (
"List Home Assistant entities. Optionally filter by domain "
"(light, switch, climate, sensor, binary_sensor, cover, fan, etc.) "
"or by area name (living room, kitchen, bedroom, etc.)."
),
"parameters": {
"type": "object",
"properties": {
"domain": {
"type": "string",
"description": (
"Entity domain to filter by (e.g. 'light', 'switch', 'climate', "
"'sensor', 'binary_sensor', 'cover', 'fan', 'media_player'). "
"Omit to list all entities."
),
},
"area": {
"type": "string",
"description": (
"Area/room name to filter by (e.g. 'living room', 'kitchen'). "
"Matches against entity friendly names. Omit to list all."
),
},
},
"required": [],
},
}
HA_GET_STATE_SCHEMA = {
"name": "ha_get_state",
"description": (
"Get the detailed state of a single Home Assistant entity, including all "
"attributes (brightness, color, temperature setpoint, sensor readings, etc.)."
),
"parameters": {
"type": "object",
"properties": {
"entity_id": {
"type": "string",
"description": (
"The entity ID to query (e.g. 'light.living_room', "
"'climate.thermostat', 'sensor.temperature')."
),
},
},
"required": ["entity_id"],
},
}
HA_CALL_SERVICE_SCHEMA = {
"name": "ha_call_service",
"description": (
"Call a Home Assistant service to control a device. Common examples: "
"turn_on/turn_off lights and switches, set_temperature for climate, "
"open_cover/close_cover for blinds, set_volume_level for media players."
),
"parameters": {
"type": "object",
"properties": {
"domain": {
"type": "string",
"description": (
"Service domain (e.g. 'light', 'switch', 'climate', "
"'cover', 'media_player', 'fan', 'scene', 'script')."
),
},
"service": {
"type": "string",
"description": (
"Service name (e.g. 'turn_on', 'turn_off', 'toggle', "
"'set_temperature', 'set_hvac_mode', 'open_cover', "
"'close_cover', 'set_volume_level')."
),
},
"entity_id": {
"type": "string",
"description": (
"Target entity ID (e.g. 'light.living_room'). "
"Some services (like scene.turn_on) may not need this."
),
},
"data": {
"type": "object",
"description": (
"Additional service data. Examples: "
'{"brightness": 255, "color_name": "blue"} for lights, '
'{"temperature": 22, "hvac_mode": "heat"} for climate, '
'{"volume_level": 0.5} for media players.'
),
},
},
"required": ["domain", "service"],
},
}
# ---------------------------------------------------------------------------
# Registration
# ---------------------------------------------------------------------------
from tools.registry import registry
registry.register(
name="ha_list_entities",
toolset="homeassistant",
schema=HA_LIST_ENTITIES_SCHEMA,
handler=_handle_list_entities,
check_fn=_check_ha_available,
)
registry.register(
name="ha_get_state",
toolset="homeassistant",
schema=HA_GET_STATE_SCHEMA,
handler=_handle_get_state,
check_fn=_check_ha_available,
)
registry.register(
name="ha_call_service",
toolset="homeassistant",
schema=HA_CALL_SERVICE_SCHEMA,
handler=_handle_call_service,
check_fn=_check_ha_available,
)

View file

@ -62,6 +62,8 @@ _HERMES_CORE_TOOLS = [
"send_message",
# Honcho user context (gated on honcho being active via check_fn)
"query_user_context",
# Home Assistant smart home control (gated on HASS_TOKEN via check_fn)
"ha_list_entities", "ha_get_state", "ha_call_service",
]
@ -194,6 +196,12 @@ TOOLSETS = {
"includes": []
},
"homeassistant": {
"description": "Home Assistant smart home control and monitoring",
"tools": ["ha_list_entities", "ha_get_state", "ha_call_service"],
"includes": []
},
# Scenario-specific toolsets
@ -247,10 +255,16 @@ TOOLSETS = {
"includes": []
},
"hermes-homeassistant": {
"description": "Home Assistant bot toolset - smart home event monitoring and control",
"tools": _HERMES_CORE_TOOLS,
"includes": []
},
"hermes-gateway": {
"description": "Gateway toolset - union of all messaging platform tools",
"tools": [],
"includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack"]
"includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-homeassistant"]
}
}

7
uv.lock generated
View file

@ -1035,6 +1035,9 @@ dev = [
{ name = "pytest" },
{ name = "pytest-asyncio" },
]
homeassistant = [
{ name = "aiohttp" },
]
honcho = [
{ name = "honcho-ai" },
]
@ -1064,6 +1067,7 @@ tts-premium = [
[package.metadata]
requires-dist = [
{ name = "aiohttp", marker = "extra == 'homeassistant'", specifier = ">=3.9.0" },
{ name = "aiohttp", marker = "extra == 'messaging'", specifier = ">=3.9.0" },
{ name = "croniter", marker = "extra == 'cron'" },
{ name = "discord-py", marker = "extra == 'messaging'", specifier = ">=2.0" },
@ -1075,6 +1079,7 @@ requires-dist = [
{ name = "hermes-agent", extras = ["cli"], marker = "extra == 'all'" },
{ name = "hermes-agent", extras = ["cron"], marker = "extra == 'all'" },
{ name = "hermes-agent", extras = ["dev"], marker = "extra == 'all'" },
{ name = "hermes-agent", extras = ["homeassistant"], marker = "extra == 'all'" },
{ name = "hermes-agent", extras = ["honcho"], marker = "extra == 'all'" },
{ name = "hermes-agent", extras = ["mcp"], marker = "extra == 'all'" },
{ name = "hermes-agent", extras = ["messaging"], marker = "extra == 'all'" },
@ -1109,7 +1114,7 @@ requires-dist = [
{ name = "tenacity" },
{ name = "typer" },
]
provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "mcp", "all"]
provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "mcp", "homeassistant", "all"]
[[package]]
name = "hf-xet"