From c36b256de56ae97e3ccabe8e97a02ae31a371e3d Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Sat, 28 Feb 2026 13:32:48 +0300 Subject: [PATCH 1/5] feat: add Home Assistant integration (REST tools + WebSocket gateway) - Add ha_list_entities, ha_get_state, ha_call_service tools via REST API - Add WebSocket gateway adapter for real-time state_changed event monitoring - Support domain/entity filtering, cooldown, and auto-reconnect with backoff - Use REST API for outbound notifications to avoid WS race condition - Gate tool availability on HASS_TOKEN env var - Add 82 unit tests covering real logic (filtering, payload building, event pipeline) --- gateway/config.py | 12 + gateway/platforms/homeassistant.py | 413 +++++++++++++++++ gateway/run.py | 9 +- model_tools.py | 1 + pyproject.toml | 2 + tests/gateway/test_homeassistant.py | 604 +++++++++++++++++++++++++ tests/tools/test_homeassistant_tool.py | 281 ++++++++++++ tools/homeassistant_tool.py | 364 +++++++++++++++ toolsets.py | 20 +- uv.lock | 7 +- 10 files changed, 1708 insertions(+), 5 deletions(-) create mode 100644 gateway/platforms/homeassistant.py create mode 100644 tests/gateway/test_homeassistant.py create mode 100644 tests/tools/test_homeassistant_tool.py create mode 100644 tools/homeassistant_tool.py diff --git a/gateway/config.py b/gateway/config.py index 32b623ea..f441e2dd 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -26,6 +26,7 @@ class Platform(Enum): DISCORD = "discord" WHATSAPP = "whatsapp" SLACK = "slack" + HOMEASSISTANT = "homeassistant" @dataclass @@ -378,6 +379,17 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""), ) + # Home Assistant + hass_token = os.getenv("HASS_TOKEN") + if hass_token: + if Platform.HOMEASSISTANT not in config.platforms: + config.platforms[Platform.HOMEASSISTANT] = PlatformConfig() + config.platforms[Platform.HOMEASSISTANT].enabled = True + config.platforms[Platform.HOMEASSISTANT].token = hass_token + hass_url = os.getenv("HASS_URL") + if hass_url: + config.platforms[Platform.HOMEASSISTANT].extra["url"] = hass_url + # Session settings idle_minutes = os.getenv("SESSION_IDLE_MINUTES") if idle_minutes: diff --git a/gateway/platforms/homeassistant.py b/gateway/platforms/homeassistant.py new file mode 100644 index 00000000..749cdf1e --- /dev/null +++ b/gateway/platforms/homeassistant.py @@ -0,0 +1,413 @@ +""" +Home Assistant platform adapter. + +Connects to the HA WebSocket API for real-time event monitoring. +State-change events are converted to MessageEvent objects and forwarded +to the agent for processing. Outbound messages are delivered as HA +persistent notifications. + +Requires: +- aiohttp (already in messaging extras) +- HASS_TOKEN env var (Long-Lived Access Token) +- HASS_URL env var (default: http://homeassistant.local:8123) +""" + +import asyncio +import json +import logging +import os +import time +from datetime import datetime +from typing import Any, Dict, List, Optional, Set + +try: + import aiohttp + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + aiohttp = None # type: ignore[assignment] + +import sys +from pathlib import Path as _Path +sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, +) + +logger = logging.getLogger(__name__) + + +def check_ha_requirements() -> bool: + """Check if Home Assistant dependencies are available and configured.""" + if not AIOHTTP_AVAILABLE: + return False + if not os.getenv("HASS_TOKEN"): + return False + return True + + +class HomeAssistantAdapter(BasePlatformAdapter): + """ + Home Assistant WebSocket adapter. + + Subscribes to ``state_changed`` events and forwards them as + MessageEvent objects. Supports domain/entity filtering and + per-entity cooldowns to avoid event floods. + """ + + MAX_MESSAGE_LENGTH = 4096 + + # Reconnection backoff schedule (seconds) + _BACKOFF_STEPS = [5, 10, 30, 60] + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.HOMEASSISTANT) + + # Connection state + self._session: Optional["aiohttp.ClientSession"] = None + self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None + self._listen_task: Optional[asyncio.Task] = None + self._msg_id: int = 0 + + # Configuration from extra + extra = config.extra or {} + token = config.token or os.getenv("HASS_TOKEN", "") + url = extra.get("url") or os.getenv("HASS_URL", "http://homeassistant.local:8123") + self._hass_url: str = url.rstrip("/") + self._hass_token: str = token + + # Event filtering + self._watch_domains: Set[str] = set(extra.get("watch_domains", [])) + self._watch_entities: Set[str] = set(extra.get("watch_entities", [])) + self._ignore_entities: Set[str] = set(extra.get("ignore_entities", [])) + self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30)) + + # Cooldown tracking: entity_id -> last_event_timestamp + self._last_event_time: Dict[str, float] = {} + + def _next_id(self) -> int: + """Return the next WebSocket message ID.""" + self._msg_id += 1 + return self._msg_id + + # ------------------------------------------------------------------ + # Connection lifecycle + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + """Connect to HA WebSocket API and subscribe to events.""" + if not AIOHTTP_AVAILABLE: + print(f"[{self.name}] aiohttp not installed. Run: pip install aiohttp") + return False + + if not self._hass_token: + print(f"[{self.name}] No HASS_TOKEN configured") + return False + + try: + success = await self._ws_connect() + if not success: + return False + + # Start background listener + self._listen_task = asyncio.create_task(self._listen_loop()) + self._running = True + print(f"[{self.name}] Connected to {self._hass_url}") + return True + + except Exception as e: + print(f"[{self.name}] Failed to connect: {e}") + return False + + async def _ws_connect(self) -> bool: + """Establish WebSocket connection and authenticate.""" + ws_url = self._hass_url.replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{ws_url}/api/websocket" + + self._session = aiohttp.ClientSession() + self._ws = await self._session.ws_connect(ws_url, heartbeat=30) + + # Step 1: Receive auth_required + msg = await self._ws.receive_json() + if msg.get("type") != "auth_required": + logger.error("Expected auth_required, got: %s", msg.get("type")) + await self._cleanup_ws() + return False + + # Step 2: Send auth + await self._ws.send_json({ + "type": "auth", + "access_token": self._hass_token, + }) + + # Step 3: Wait for auth_ok + msg = await self._ws.receive_json() + if msg.get("type") != "auth_ok": + logger.error("Auth failed: %s", msg) + await self._cleanup_ws() + return False + + # Step 4: Subscribe to state_changed events + sub_id = self._next_id() + await self._ws.send_json({ + "id": sub_id, + "type": "subscribe_events", + "event_type": "state_changed", + }) + + # Verify subscription acknowledgement + msg = await self._ws.receive_json() + if not msg.get("success"): + logger.error("Failed to subscribe to events: %s", msg) + await self._cleanup_ws() + return False + + return True + + async def _cleanup_ws(self) -> None: + """Close WebSocket and session.""" + if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + async def disconnect(self) -> None: + """Disconnect from Home Assistant.""" + self._running = False + if self._listen_task: + self._listen_task.cancel() + try: + await self._listen_task + except asyncio.CancelledError: + pass + self._listen_task = None + + await self._cleanup_ws() + print(f"[{self.name}] Disconnected") + + # ------------------------------------------------------------------ + # Event listener + # ------------------------------------------------------------------ + + async def _listen_loop(self) -> None: + """Main event loop with automatic reconnection.""" + backoff_idx = 0 + + while self._running: + try: + await self._read_events() + except asyncio.CancelledError: + return + except Exception as e: + logger.warning("[%s] WebSocket error: %s", self.name, e) + + if not self._running: + return + + # Reconnect with backoff + delay = self._BACKOFF_STEPS[min(backoff_idx, len(self._BACKOFF_STEPS) - 1)] + print(f"[{self.name}] Reconnecting in {delay}s...") + await asyncio.sleep(delay) + backoff_idx += 1 + + try: + await self._cleanup_ws() + success = await self._ws_connect() + if success: + backoff_idx = 0 # Reset on successful reconnect + print(f"[{self.name}] Reconnected") + except Exception as e: + logger.warning("[%s] Reconnection failed: %s", self.name, e) + + async def _read_events(self) -> None: + """Read events from WebSocket until disconnected.""" + async for ws_msg in self._ws: + if ws_msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(ws_msg.data) + if data.get("type") == "event": + await self._handle_ha_event(data.get("event", {})) + except json.JSONDecodeError: + logger.debug("Invalid JSON from HA WS: %s", ws_msg.data[:200]) + elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR): + break + + async def _handle_ha_event(self, event: Dict[str, Any]) -> None: + """Process a state_changed event from Home Assistant.""" + event_data = event.get("data", {}) + entity_id: str = event_data.get("entity_id", "") + + if not entity_id: + return + + # Apply ignore filter + if entity_id in self._ignore_entities: + return + + # Apply domain/entity watch filters + domain = entity_id.split(".")[0] if "." in entity_id else "" + if self._watch_domains or self._watch_entities: + domain_match = domain in self._watch_domains if self._watch_domains else False + entity_match = entity_id in self._watch_entities if self._watch_entities else False + if not domain_match and not entity_match: + return + + # Apply cooldown + now = time.time() + last = self._last_event_time.get(entity_id, 0) + if (now - last) < self._cooldown_seconds: + return + self._last_event_time[entity_id] = now + + # Build human-readable message + old_state = event_data.get("old_state", {}) + new_state = event_data.get("new_state", {}) + message = self._format_state_change(entity_id, old_state, new_state) + + if not message: + return + + # Build MessageEvent and forward to handler + source = self.build_source( + chat_id="ha_events", + chat_name="Home Assistant Events", + chat_type="channel", + user_id="homeassistant", + user_name="Home Assistant", + ) + + msg_event = MessageEvent( + text=message, + message_type=MessageType.TEXT, + source=source, + message_id=f"ha_{entity_id}_{int(now)}", + timestamp=datetime.now(), + ) + + await self.handle_message(msg_event) + + @staticmethod + def _format_state_change( + entity_id: str, + old_state: Dict[str, Any], + new_state: Dict[str, Any], + ) -> Optional[str]: + """Convert a state_changed event into a human-readable description.""" + if not new_state: + return None + + old_val = old_state.get("state", "unknown") if old_state else "unknown" + new_val = new_state.get("state", "unknown") + + # Skip if state didn't actually change + if old_val == new_val: + return None + + friendly_name = new_state.get("attributes", {}).get("friendly_name", entity_id) + domain = entity_id.split(".")[0] if "." in entity_id else "" + + # Domain-specific formatting + if domain == "climate": + attrs = new_state.get("attributes", {}) + temp = attrs.get("current_temperature", "?") + target = attrs.get("temperature", "?") + return ( + f"[Home Assistant] {friendly_name}: HVAC mode changed from " + f"'{old_val}' to '{new_val}' (current: {temp}, target: {target})" + ) + + if domain == "sensor": + unit = new_state.get("attributes", {}).get("unit_of_measurement", "") + return ( + f"[Home Assistant] {friendly_name}: changed from " + f"{old_val}{unit} to {new_val}{unit}" + ) + + if domain == "binary_sensor": + return ( + f"[Home Assistant] {friendly_name}: " + f"{'triggered' if new_val == 'on' else 'cleared'} " + f"(was {'triggered' if old_val == 'on' else 'cleared'})" + ) + + if domain in ("light", "switch", "fan"): + return ( + f"[Home Assistant] {friendly_name}: turned " + f"{'on' if new_val == 'on' else 'off'}" + ) + + if domain == "alarm_control_panel": + return ( + f"[Home Assistant] {friendly_name}: alarm state changed from " + f"'{old_val}' to '{new_val}'" + ) + + # Generic fallback + return ( + f"[Home Assistant] {friendly_name} ({entity_id}): " + f"changed from '{old_val}' to '{new_val}'" + ) + + # ------------------------------------------------------------------ + # Outbound messaging + # ------------------------------------------------------------------ + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a notification via HA REST API (persistent_notification.create). + + Uses the REST API instead of WebSocket to avoid a race condition + with the event listener loop that reads from the same WS connection. + """ + url = f"{self._hass_url}/api/services/persistent_notification/create" + headers = { + "Authorization": f"Bearer {self._hass_token}", + "Content-Type": "application/json", + } + payload = { + "title": "Hermes Agent", + "message": content[:self.MAX_MESSAGE_LENGTH], + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + url, + headers=headers, + json=payload, + timeout=aiohttp.ClientTimeout(total=10), + ) as resp: + if resp.status < 300: + return SendResult(success=True, message_id=str(self._next_id())) + else: + body = await resp.text() + return SendResult(success=False, error=f"HTTP {resp.status}: {body}") + + except asyncio.TimeoutError: + return SendResult(success=False, error="Timeout sending notification to HA") + except Exception as e: + return SendResult(success=False, error=str(e)) + + async def send_typing(self, chat_id: str) -> None: + """No typing indicator for Home Assistant.""" + pass + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return basic info about the HA event channel.""" + return { + "name": "Home Assistant Events", + "type": "channel", + "url": self._hass_url, + } diff --git a/gateway/run.py b/gateway/run.py index bcd2457b..76ed3666 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -469,7 +469,14 @@ class GatewayRunner: logger.warning("Slack: slack-bolt not installed. Run: pip install 'hermes-agent[slack]'") return None return SlackAdapter(config) - + + elif platform == Platform.HOMEASSISTANT: + from gateway.platforms.homeassistant import HomeAssistantAdapter, check_ha_requirements + if not check_ha_requirements(): + logger.warning("HomeAssistant: aiohttp not installed or HASS_TOKEN not set") + return None + return HomeAssistantAdapter(config) + return None def _is_user_authorized(self, source: SessionSource) -> bool: diff --git a/model_tools.py b/model_tools.py index 036bb34b..38f01385 100644 --- a/model_tools.py +++ b/model_tools.py @@ -94,6 +94,7 @@ def _discover_tools(): "tools.process_registry", "tools.send_message_tool", "tools.honcho_tools", + "tools.homeassistant_tool", ] import importlib for mod_name in _modules: diff --git a/pyproject.toml b/pyproject.toml index 152b4730..a002f1bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ cli = ["simple-term-menu"] tts-premium = ["elevenlabs"] pty = ["ptyprocess>=0.7.0"] honcho = ["honcho-ai>=2.0.1"] +homeassistant = ["aiohttp>=3.9.0"] all = [ "hermes-agent[modal]", "hermes-agent[messaging]", @@ -57,6 +58,7 @@ all = [ "hermes-agent[slack]", "hermes-agent[pty]", "hermes-agent[honcho]", + "hermes-agent[homeassistant]", ] [project.scripts] diff --git a/tests/gateway/test_homeassistant.py b/tests/gateway/test_homeassistant.py new file mode 100644 index 00000000..f8bf7844 --- /dev/null +++ b/tests/gateway/test_homeassistant.py @@ -0,0 +1,604 @@ +"""Tests for the Home Assistant gateway adapter. + +Tests real logic: state change formatting, event filtering pipeline, +cooldown behavior, config integration, and adapter initialization. +""" + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import ( + GatewayConfig, + Platform, + PlatformConfig, +) +from gateway.platforms.homeassistant import ( + HomeAssistantAdapter, + check_ha_requirements, +) + + +# --------------------------------------------------------------------------- +# check_ha_requirements +# --------------------------------------------------------------------------- + + +class TestCheckRequirements: + def test_returns_false_without_token(self, monkeypatch): + monkeypatch.delenv("HASS_TOKEN", raising=False) + assert check_ha_requirements() is False + + def test_returns_true_with_token(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "test-token") + assert check_ha_requirements() is True + + @patch("gateway.platforms.homeassistant.AIOHTTP_AVAILABLE", False) + def test_returns_false_without_aiohttp(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "test-token") + assert check_ha_requirements() is False + + +# --------------------------------------------------------------------------- +# _format_state_change - pure function, all domain branches +# --------------------------------------------------------------------------- + + +class TestFormatStateChange: + @staticmethod + def fmt(entity_id, old_state, new_state): + return HomeAssistantAdapter._format_state_change(entity_id, old_state, new_state) + + def test_climate_includes_temperatures(self): + msg = self.fmt( + "climate.thermostat", + {"state": "off"}, + {"state": "heat", "attributes": { + "friendly_name": "Main Thermostat", + "current_temperature": 21.5, + "temperature": 23, + }}, + ) + assert "Main Thermostat" in msg + assert "'off'" in msg and "'heat'" in msg + assert "21.5" in msg and "23" in msg + + def test_sensor_includes_unit(self): + msg = self.fmt( + "sensor.temperature", + {"state": "22.5"}, + {"state": "25.1", "attributes": { + "friendly_name": "Living Room Temp", + "unit_of_measurement": "C", + }}, + ) + assert "22.5C" in msg and "25.1C" in msg + assert "Living Room Temp" in msg + + def test_sensor_without_unit(self): + msg = self.fmt( + "sensor.count", + {"state": "5"}, + {"state": "10", "attributes": {"friendly_name": "Counter"}}, + ) + assert "5" in msg and "10" in msg + + def test_binary_sensor_on(self): + msg = self.fmt( + "binary_sensor.motion", + {"state": "off"}, + {"state": "on", "attributes": {"friendly_name": "Hallway Motion"}}, + ) + assert "triggered" in msg + assert "Hallway Motion" in msg + + def test_binary_sensor_off(self): + msg = self.fmt( + "binary_sensor.door", + {"state": "on"}, + {"state": "off", "attributes": {"friendly_name": "Front Door"}}, + ) + assert "cleared" in msg + + def test_light_turned_on(self): + msg = self.fmt( + "light.bedroom", + {"state": "off"}, + {"state": "on", "attributes": {"friendly_name": "Bedroom Light"}}, + ) + assert "turned on" in msg + + def test_switch_turned_off(self): + msg = self.fmt( + "switch.heater", + {"state": "on"}, + {"state": "off", "attributes": {"friendly_name": "Heater"}}, + ) + assert "turned off" in msg + + def test_fan_domain_uses_light_switch_branch(self): + msg = self.fmt( + "fan.ceiling", + {"state": "off"}, + {"state": "on", "attributes": {"friendly_name": "Ceiling Fan"}}, + ) + assert "turned on" in msg + + def test_alarm_panel(self): + msg = self.fmt( + "alarm_control_panel.home", + {"state": "disarmed"}, + {"state": "armed_away", "attributes": {"friendly_name": "Home Alarm"}}, + ) + assert "Home Alarm" in msg + assert "armed_away" in msg and "disarmed" in msg + + def test_generic_domain_includes_entity_id(self): + msg = self.fmt( + "automation.morning", + {"state": "off"}, + {"state": "on", "attributes": {"friendly_name": "Morning Routine"}}, + ) + assert "automation.morning" in msg + assert "Morning Routine" in msg + + def test_same_state_returns_none(self): + assert self.fmt( + "sensor.temp", + {"state": "22"}, + {"state": "22", "attributes": {"friendly_name": "Temp"}}, + ) is None + + def test_empty_new_state_returns_none(self): + assert self.fmt("light.x", {"state": "on"}, {}) is None + + def test_no_old_state_uses_unknown(self): + msg = self.fmt( + "light.new", + None, + {"state": "on", "attributes": {"friendly_name": "New Light"}}, + ) + assert msg is not None + assert "New Light" in msg + + def test_uses_entity_id_when_no_friendly_name(self): + msg = self.fmt( + "sensor.unnamed", + {"state": "1"}, + {"state": "2", "attributes": {}}, + ) + assert "sensor.unnamed" in msg + + +# --------------------------------------------------------------------------- +# Adapter initialization from config +# --------------------------------------------------------------------------- + + +class TestAdapterInit: + def test_url_and_token_from_config_extra(self, monkeypatch): + monkeypatch.delenv("HASS_URL", raising=False) + monkeypatch.delenv("HASS_TOKEN", raising=False) + + config = PlatformConfig( + enabled=True, + token="config-token", + extra={"url": "http://192.168.1.50:8123"}, + ) + adapter = HomeAssistantAdapter(config) + assert adapter._hass_token == "config-token" + assert adapter._hass_url == "http://192.168.1.50:8123" + + def test_url_fallback_to_env(self, monkeypatch): + monkeypatch.setenv("HASS_URL", "http://env-host:8123") + monkeypatch.setenv("HASS_TOKEN", "env-tok") + + config = PlatformConfig(enabled=True, token="env-tok") + adapter = HomeAssistantAdapter(config) + assert adapter._hass_url == "http://env-host:8123" + + def test_trailing_slash_stripped(self): + config = PlatformConfig( + enabled=True, token="t", + extra={"url": "http://ha.local:8123/"}, + ) + adapter = HomeAssistantAdapter(config) + assert adapter._hass_url == "http://ha.local:8123" + + def test_watch_filters_parsed(self): + config = PlatformConfig( + enabled=True, token="t", + extra={ + "watch_domains": ["climate", "binary_sensor"], + "watch_entities": ["sensor.special"], + "ignore_entities": ["sensor.uptime", "sensor.cpu"], + "cooldown_seconds": 120, + }, + ) + adapter = HomeAssistantAdapter(config) + assert adapter._watch_domains == {"climate", "binary_sensor"} + assert adapter._watch_entities == {"sensor.special"} + assert adapter._ignore_entities == {"sensor.uptime", "sensor.cpu"} + assert adapter._cooldown_seconds == 120 + + def test_defaults_when_no_extra(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "tok") + config = PlatformConfig(enabled=True, token="tok") + adapter = HomeAssistantAdapter(config) + assert adapter._watch_domains == set() + assert adapter._watch_entities == set() + assert adapter._ignore_entities == set() + assert adapter._cooldown_seconds == 30 + + +# --------------------------------------------------------------------------- +# Event filtering pipeline (_handle_ha_event) +# +# We mock handle_message (not our code, it's the base class pipeline) to +# capture the MessageEvent that _handle_ha_event produces. +# --------------------------------------------------------------------------- + + +def _make_adapter(**extra) -> HomeAssistantAdapter: + config = PlatformConfig(enabled=True, token="tok", extra=extra) + adapter = HomeAssistantAdapter(config) + adapter.handle_message = AsyncMock() + return adapter + + +def _make_event(entity_id, old_state, new_state, old_attrs=None, new_attrs=None): + return { + "data": { + "entity_id": entity_id, + "old_state": {"state": old_state, "attributes": old_attrs or {}}, + "new_state": {"state": new_state, "attributes": new_attrs or {"friendly_name": entity_id}}, + } + } + + +class TestEventFilteringPipeline: + @pytest.mark.asyncio + async def test_ignored_entity_not_forwarded(self): + adapter = _make_adapter(ignore_entities=["sensor.uptime"]) + await adapter._handle_ha_event(_make_event("sensor.uptime", "100", "101")) + adapter.handle_message.assert_not_called() + + @pytest.mark.asyncio + async def test_unwatched_domain_not_forwarded(self): + adapter = _make_adapter(watch_domains=["climate"]) + await adapter._handle_ha_event(_make_event("light.bedroom", "off", "on")) + adapter.handle_message.assert_not_called() + + @pytest.mark.asyncio + async def test_watched_domain_forwarded(self): + adapter = _make_adapter(watch_domains=["climate"], cooldown_seconds=0) + await adapter._handle_ha_event( + _make_event("climate.thermostat", "off", "heat", + new_attrs={"friendly_name": "Thermostat", "current_temperature": 20, "temperature": 22}) + ) + adapter.handle_message.assert_called_once() + + # Verify the actual MessageEvent text content + msg_event = adapter.handle_message.call_args[0][0] + assert "Thermostat" in msg_event.text + assert "heat" in msg_event.text + assert msg_event.source.platform == Platform.HOMEASSISTANT + assert msg_event.source.chat_id == "ha_events" + + @pytest.mark.asyncio + async def test_watched_entity_forwarded(self): + adapter = _make_adapter(watch_entities=["sensor.important"], cooldown_seconds=0) + await adapter._handle_ha_event( + _make_event("sensor.important", "10", "20", + new_attrs={"friendly_name": "Important Sensor", "unit_of_measurement": "W"}) + ) + adapter.handle_message.assert_called_once() + msg_event = adapter.handle_message.call_args[0][0] + assert "10W" in msg_event.text and "20W" in msg_event.text + + @pytest.mark.asyncio + async def test_no_filters_passes_everything(self): + adapter = _make_adapter(cooldown_seconds=0) + await adapter._handle_ha_event(_make_event("cover.blinds", "closed", "open")) + adapter.handle_message.assert_called_once() + + @pytest.mark.asyncio + async def test_same_state_not_forwarded(self): + adapter = _make_adapter(cooldown_seconds=0) + await adapter._handle_ha_event(_make_event("light.x", "on", "on")) + adapter.handle_message.assert_not_called() + + @pytest.mark.asyncio + async def test_empty_entity_id_skipped(self): + adapter = _make_adapter() + await adapter._handle_ha_event({"data": {"entity_id": ""}}) + adapter.handle_message.assert_not_called() + + @pytest.mark.asyncio + async def test_message_event_has_correct_source(self): + adapter = _make_adapter(cooldown_seconds=0) + await adapter._handle_ha_event( + _make_event("light.test", "off", "on", + new_attrs={"friendly_name": "Test Light"}) + ) + msg_event = adapter.handle_message.call_args[0][0] + assert msg_event.source.user_name == "Home Assistant" + assert msg_event.source.chat_type == "channel" + assert msg_event.message_id.startswith("ha_light.test_") + + +# --------------------------------------------------------------------------- +# Cooldown behavior +# --------------------------------------------------------------------------- + + +class TestCooldown: + @pytest.mark.asyncio + async def test_cooldown_blocks_rapid_events(self): + adapter = _make_adapter(cooldown_seconds=60) + + event = _make_event("sensor.temp", "20", "21", + new_attrs={"friendly_name": "Temp"}) + await adapter._handle_ha_event(event) + assert adapter.handle_message.call_count == 1 + + # Second event immediately after should be blocked + event2 = _make_event("sensor.temp", "21", "22", + new_attrs={"friendly_name": "Temp"}) + await adapter._handle_ha_event(event2) + assert adapter.handle_message.call_count == 1 # Still 1 + + @pytest.mark.asyncio + async def test_cooldown_expires(self): + adapter = _make_adapter(cooldown_seconds=1) + + event = _make_event("sensor.temp", "20", "21", + new_attrs={"friendly_name": "Temp"}) + await adapter._handle_ha_event(event) + assert adapter.handle_message.call_count == 1 + + # Simulate time passing beyond cooldown + adapter._last_event_time["sensor.temp"] = time.time() - 2 + + event2 = _make_event("sensor.temp", "21", "22", + new_attrs={"friendly_name": "Temp"}) + await adapter._handle_ha_event(event2) + assert adapter.handle_message.call_count == 2 + + @pytest.mark.asyncio + async def test_different_entities_independent_cooldowns(self): + adapter = _make_adapter(cooldown_seconds=60) + + await adapter._handle_ha_event( + _make_event("sensor.a", "1", "2", new_attrs={"friendly_name": "A"}) + ) + await adapter._handle_ha_event( + _make_event("sensor.b", "3", "4", new_attrs={"friendly_name": "B"}) + ) + # Both should pass - different entities + assert adapter.handle_message.call_count == 2 + + # Same entity again - should be blocked + await adapter._handle_ha_event( + _make_event("sensor.a", "2", "3", new_attrs={"friendly_name": "A"}) + ) + assert adapter.handle_message.call_count == 2 # Still 2 + + @pytest.mark.asyncio + async def test_zero_cooldown_passes_all(self): + adapter = _make_adapter(cooldown_seconds=0) + + for i in range(5): + await adapter._handle_ha_event( + _make_event("sensor.temp", str(i), str(i + 1), + new_attrs={"friendly_name": "Temp"}) + ) + assert adapter.handle_message.call_count == 5 + + +# --------------------------------------------------------------------------- +# Config integration (env overrides, round-trip) +# --------------------------------------------------------------------------- + + +class TestConfigIntegration: + def test_env_override_creates_ha_platform(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "env-token") + monkeypatch.setenv("HASS_URL", "http://10.0.0.5:8123") + # Clear other platform tokens + for v in ["TELEGRAM_BOT_TOKEN", "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]: + monkeypatch.delenv(v, raising=False) + + from gateway.config import load_gateway_config + config = load_gateway_config() + + assert Platform.HOMEASSISTANT in config.platforms + ha = config.platforms[Platform.HOMEASSISTANT] + assert ha.enabled is True + assert ha.token == "env-token" + assert ha.extra["url"] == "http://10.0.0.5:8123" + + def test_no_env_no_platform(self, monkeypatch): + for v in ["HASS_TOKEN", "HASS_URL", "TELEGRAM_BOT_TOKEN", + "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]: + monkeypatch.delenv(v, raising=False) + + from gateway.config import load_gateway_config + config = load_gateway_config() + assert Platform.HOMEASSISTANT not in config.platforms + + def test_config_roundtrip_preserves_extra(self): + config = GatewayConfig( + platforms={ + Platform.HOMEASSISTANT: PlatformConfig( + enabled=True, + token="tok", + extra={ + "url": "http://ha:8123", + "watch_domains": ["climate"], + "cooldown_seconds": 45, + }, + ), + }, + ) + d = config.to_dict() + restored = GatewayConfig.from_dict(d) + + ha = restored.platforms[Platform.HOMEASSISTANT] + assert ha.enabled is True + assert ha.token == "tok" + assert ha.extra["watch_domains"] == ["climate"] + assert ha.extra["cooldown_seconds"] == 45 + + def test_connected_platforms_includes_ha(self): + config = GatewayConfig( + platforms={ + Platform.HOMEASSISTANT: PlatformConfig(enabled=True, token="tok"), + Platform.TELEGRAM: PlatformConfig(enabled=False, token="t"), + }, + ) + connected = config.get_connected_platforms() + assert Platform.HOMEASSISTANT in connected + assert Platform.TELEGRAM not in connected + + +# --------------------------------------------------------------------------- +# send() via REST API +# --------------------------------------------------------------------------- + + +class TestSendViaRestApi: + """send() uses REST API (not WebSocket) to avoid race conditions.""" + + @staticmethod + def _mock_aiohttp_session(response_status=200, response_text="OK"): + """Build a mock aiohttp session + response for async-with patterns. + + aiohttp.ClientSession() is a sync constructor whose return value + is used as ``async with session:``. ``session.post(...)`` returns a + context-manager (not a coroutine), so both layers use MagicMock for + the call and AsyncMock only for ``__aenter__`` / ``__aexit__``. + """ + mock_response = MagicMock() + mock_response.status = response_status + mock_response.text = AsyncMock(return_value=response_text) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=False) + + mock_session = MagicMock() + mock_session.post = MagicMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + return mock_session + + @pytest.mark.asyncio + async def test_send_success(self): + adapter = _make_adapter() + mock_session = self._mock_aiohttp_session(200) + + with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp: + mock_aiohttp.ClientSession = MagicMock(return_value=mock_session) + mock_aiohttp.ClientTimeout = lambda total: total + + result = await adapter.send("ha_events", "Test notification") + + assert result.success is True + # Verify the REST API was called with correct payload + call_args = mock_session.post.call_args + assert "/api/services/persistent_notification/create" in call_args[0][0] + assert call_args[1]["json"]["title"] == "Hermes Agent" + assert call_args[1]["json"]["message"] == "Test notification" + assert "Bearer tok" in call_args[1]["headers"]["Authorization"] + + @pytest.mark.asyncio + async def test_send_http_error(self): + adapter = _make_adapter() + mock_session = self._mock_aiohttp_session(401, "Unauthorized") + + with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp: + mock_aiohttp.ClientSession = MagicMock(return_value=mock_session) + mock_aiohttp.ClientTimeout = lambda total: total + + result = await adapter.send("ha_events", "Test") + + assert result.success is False + assert "401" in result.error + + @pytest.mark.asyncio + async def test_send_truncates_long_message(self): + adapter = _make_adapter() + mock_session = self._mock_aiohttp_session(200) + long_message = "x" * 10000 + + with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp: + mock_aiohttp.ClientSession = MagicMock(return_value=mock_session) + mock_aiohttp.ClientTimeout = lambda total: total + + await adapter.send("ha_events", long_message) + + sent_message = mock_session.post.call_args[1]["json"]["message"] + assert len(sent_message) == 4096 + + @pytest.mark.asyncio + async def test_send_does_not_use_websocket(self): + """send() must use REST API, not the WS connection (race condition fix).""" + adapter = _make_adapter() + adapter._ws = AsyncMock() # Simulate an active WS + mock_session = self._mock_aiohttp_session(200) + + with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp: + mock_aiohttp.ClientSession = MagicMock(return_value=mock_session) + mock_aiohttp.ClientTimeout = lambda total: total + + await adapter.send("ha_events", "Test") + + # WS should NOT have been used for sending + adapter._ws.send_json.assert_not_called() + adapter._ws.receive_json.assert_not_called() + + +# --------------------------------------------------------------------------- +# Toolset integration +# --------------------------------------------------------------------------- + + +class TestToolsetIntegration: + def test_homeassistant_toolset_resolves(self): + from toolsets import resolve_toolset + + tools = resolve_toolset("homeassistant") + assert set(tools) == {"ha_list_entities", "ha_get_state", "ha_call_service"} + + def test_gateway_toolset_includes_ha_tools(self): + from toolsets import resolve_toolset + + gateway_tools = resolve_toolset("hermes-gateway") + for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"): + assert tool in gateway_tools + + def test_hermes_core_tools_includes_ha(self): + from toolsets import _HERMES_CORE_TOOLS + + for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"): + assert tool in _HERMES_CORE_TOOLS + + +# --------------------------------------------------------------------------- +# WebSocket URL construction +# --------------------------------------------------------------------------- + + +class TestWsUrlConstruction: + def test_http_to_ws(self): + config = PlatformConfig(enabled=True, token="t", extra={"url": "http://ha:8123"}) + adapter = HomeAssistantAdapter(config) + ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://") + assert ws_url == "ws://ha:8123" + + def test_https_to_wss(self): + config = PlatformConfig(enabled=True, token="t", extra={"url": "https://ha.example.com"}) + adapter = HomeAssistantAdapter(config) + ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://") + assert ws_url == "wss://ha.example.com" diff --git a/tests/tools/test_homeassistant_tool.py b/tests/tools/test_homeassistant_tool.py new file mode 100644 index 00000000..6235474e --- /dev/null +++ b/tests/tools/test_homeassistant_tool.py @@ -0,0 +1,281 @@ +"""Tests for the Home Assistant tool module. + +Tests real logic: entity filtering, payload building, response parsing, +handler validation, and availability gating. +""" + +import json + +import pytest + +from tools.homeassistant_tool import ( + _check_ha_available, + _filter_and_summarize, + _build_service_payload, + _parse_service_response, + _get_headers, + _handle_get_state, + _handle_call_service, +) + + +# --------------------------------------------------------------------------- +# Sample HA state data (matches real HA /api/states response shape) +# --------------------------------------------------------------------------- + +SAMPLE_STATES = [ + {"entity_id": "light.bedroom", "state": "on", "attributes": {"friendly_name": "Bedroom Light", "brightness": 200}}, + {"entity_id": "light.kitchen", "state": "off", "attributes": {"friendly_name": "Kitchen Light"}}, + {"entity_id": "switch.fan", "state": "on", "attributes": {"friendly_name": "Living Room Fan"}}, + {"entity_id": "sensor.temperature", "state": "22.5", "attributes": {"friendly_name": "Kitchen Temperature", "unit_of_measurement": "C"}}, + {"entity_id": "climate.thermostat", "state": "heat", "attributes": {"friendly_name": "Main Thermostat", "current_temperature": 21}}, + {"entity_id": "binary_sensor.motion", "state": "off", "attributes": {"friendly_name": "Hallway Motion"}}, + {"entity_id": "sensor.humidity", "state": "55", "attributes": {"friendly_name": "Bedroom Humidity", "area": "bedroom"}}, +] + + +# --------------------------------------------------------------------------- +# Entity filtering and summarization +# --------------------------------------------------------------------------- + + +class TestFilterAndSummarize: + def test_no_filters_returns_all(self): + result = _filter_and_summarize(SAMPLE_STATES) + assert result["count"] == 7 + ids = {e["entity_id"] for e in result["entities"]} + assert "light.bedroom" in ids + assert "climate.thermostat" in ids + + def test_domain_filter_lights(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="light") + assert result["count"] == 2 + for e in result["entities"]: + assert e["entity_id"].startswith("light.") + + def test_domain_filter_sensor(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="sensor") + assert result["count"] == 2 + ids = {e["entity_id"] for e in result["entities"]} + assert ids == {"sensor.temperature", "sensor.humidity"} + + def test_domain_filter_no_matches(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="media_player") + assert result["count"] == 0 + assert result["entities"] == [] + + def test_area_filter_by_friendly_name(self): + result = _filter_and_summarize(SAMPLE_STATES, area="kitchen") + assert result["count"] == 2 + ids = {e["entity_id"] for e in result["entities"]} + assert "light.kitchen" in ids + assert "sensor.temperature" in ids + + def test_area_filter_by_area_attribute(self): + result = _filter_and_summarize(SAMPLE_STATES, area="bedroom") + ids = {e["entity_id"] for e in result["entities"]} + # "Bedroom Light" matches via friendly_name, "Bedroom Humidity" matches via area attr + assert "light.bedroom" in ids + assert "sensor.humidity" in ids + + def test_area_filter_case_insensitive(self): + result = _filter_and_summarize(SAMPLE_STATES, area="KITCHEN") + assert result["count"] == 2 + + def test_combined_domain_and_area(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="sensor", area="kitchen") + assert result["count"] == 1 + assert result["entities"][0]["entity_id"] == "sensor.temperature" + + def test_summary_includes_friendly_name(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="climate") + assert result["entities"][0]["friendly_name"] == "Main Thermostat" + assert result["entities"][0]["state"] == "heat" + + def test_empty_states_list(self): + result = _filter_and_summarize([]) + assert result["count"] == 0 + + def test_missing_attributes_handled(self): + states = [{"entity_id": "light.x", "state": "on"}] + result = _filter_and_summarize(states) + assert result["count"] == 1 + assert result["entities"][0]["friendly_name"] == "" + + +# --------------------------------------------------------------------------- +# Service payload building +# --------------------------------------------------------------------------- + + +class TestBuildServicePayload: + def test_entity_id_only(self): + payload = _build_service_payload(entity_id="light.bedroom") + assert payload == {"entity_id": "light.bedroom"} + + def test_data_only(self): + payload = _build_service_payload(data={"brightness": 255}) + assert payload == {"brightness": 255} + + def test_entity_id_and_data(self): + payload = _build_service_payload( + entity_id="light.bedroom", + data={"brightness": 200, "color_name": "blue"}, + ) + assert payload["entity_id"] == "light.bedroom" + assert payload["brightness"] == 200 + assert payload["color_name"] == "blue" + + def test_no_args_returns_empty(self): + payload = _build_service_payload() + assert payload == {} + + def test_data_does_not_overwrite_entity_id(self): + payload = _build_service_payload( + entity_id="light.a", + data={"entity_id": "light.b"}, + ) + # data.update overwrites entity_id set earlier + assert payload["entity_id"] == "light.b" + + +# --------------------------------------------------------------------------- +# Service response parsing +# --------------------------------------------------------------------------- + + +class TestParseServiceResponse: + def test_list_response_extracts_entities(self): + ha_response = [ + {"entity_id": "light.bedroom", "state": "on", "attributes": {}}, + {"entity_id": "light.kitchen", "state": "on", "attributes": {}}, + ] + result = _parse_service_response("light", "turn_on", ha_response) + assert result["success"] is True + assert result["service"] == "light.turn_on" + assert len(result["affected_entities"]) == 2 + assert result["affected_entities"][0]["entity_id"] == "light.bedroom" + + def test_empty_list_response(self): + result = _parse_service_response("scene", "turn_on", []) + assert result["success"] is True + assert result["affected_entities"] == [] + + def test_non_list_response(self): + # Some HA services return a dict instead of a list + result = _parse_service_response("script", "run", {"result": "ok"}) + assert result["success"] is True + assert result["affected_entities"] == [] + + def test_none_response(self): + result = _parse_service_response("automation", "trigger", None) + assert result["success"] is True + assert result["affected_entities"] == [] + + def test_service_name_format(self): + result = _parse_service_response("climate", "set_temperature", []) + assert result["service"] == "climate.set_temperature" + + +# --------------------------------------------------------------------------- +# Handler validation (no mocks - these paths don't reach the network) +# --------------------------------------------------------------------------- + + +class TestHandlerValidation: + def test_get_state_missing_entity_id(self): + result = json.loads(_handle_get_state({})) + assert "error" in result + assert "entity_id" in result["error"] + + def test_get_state_empty_entity_id(self): + result = json.loads(_handle_get_state({"entity_id": ""})) + assert "error" in result + + def test_call_service_missing_domain(self): + result = json.loads(_handle_call_service({"service": "turn_on"})) + assert "error" in result + assert "domain" in result["error"] + + def test_call_service_missing_service(self): + result = json.loads(_handle_call_service({"domain": "light"})) + assert "error" in result + assert "service" in result["error"] + + def test_call_service_missing_both(self): + result = json.loads(_handle_call_service({})) + assert "error" in result + + def test_call_service_empty_strings(self): + result = json.loads(_handle_call_service({"domain": "", "service": ""})) + assert "error" in result + + +# --------------------------------------------------------------------------- +# Availability check +# --------------------------------------------------------------------------- + + +class TestCheckAvailable: + def test_unavailable_without_token(self, monkeypatch): + monkeypatch.delenv("HASS_TOKEN", raising=False) + assert _check_ha_available() is False + + def test_available_with_token(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "eyJ0eXAiOiJKV1Q") + assert _check_ha_available() is True + + def test_empty_token_is_unavailable(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "") + assert _check_ha_available() is False + + +# --------------------------------------------------------------------------- +# Auth headers +# --------------------------------------------------------------------------- + + +class TestGetHeaders: + def test_bearer_token_format(self, monkeypatch): + monkeypatch.setattr("tools.homeassistant_tool._HASS_TOKEN", "my-secret-token") + headers = _get_headers() + assert headers["Authorization"] == "Bearer my-secret-token" + assert headers["Content-Type"] == "application/json" + + +# --------------------------------------------------------------------------- +# Registry integration +# --------------------------------------------------------------------------- + + +class TestRegistration: + def test_tools_registered_in_registry(self): + from tools.registry import registry + + names = registry.get_all_tool_names() + assert "ha_list_entities" in names + assert "ha_get_state" in names + assert "ha_call_service" in names + + def test_tools_in_homeassistant_toolset(self): + from tools.registry import registry + + toolset_map = registry.get_tool_to_toolset_map() + for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"): + assert toolset_map[tool] == "homeassistant" + + def test_check_fn_gates_availability(self, monkeypatch): + """Registry should exclude HA tools when HASS_TOKEN is not set.""" + from tools.registry import registry + + monkeypatch.delenv("HASS_TOKEN", raising=False) + defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"}) + assert len(defs) == 0 + + def test_check_fn_includes_when_token_set(self, monkeypatch): + """Registry should include HA tools when HASS_TOKEN is set.""" + from tools.registry import registry + + monkeypatch.setenv("HASS_TOKEN", "test-token") + defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"}) + assert len(defs) == 3 diff --git a/tools/homeassistant_tool.py b/tools/homeassistant_tool.py new file mode 100644 index 00000000..4a01382f --- /dev/null +++ b/tools/homeassistant_tool.py @@ -0,0 +1,364 @@ +"""Home Assistant tool for controlling smart home devices via REST API. + +Registers three LLM-callable tools: +- ``ha_list_entities`` -- list/filter entities by domain or area +- ``ha_get_state`` -- get detailed state of a single entity +- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.) + +Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var. +The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123). +""" + +import asyncio +import json +import logging +import os +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/") +_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "") + + +def _get_headers() -> Dict[str, str]: + """Return authorization headers for HA REST API.""" + return { + "Authorization": f"Bearer {_HASS_TOKEN}", + "Content-Type": "application/json", + } + + +# --------------------------------------------------------------------------- +# Async helpers (called from sync handlers via run_until_complete) +# --------------------------------------------------------------------------- + +def _filter_and_summarize( + states: list, + domain: Optional[str] = None, + area: Optional[str] = None, +) -> Dict[str, Any]: + """Filter raw HA states by domain/area and return a compact summary.""" + if domain: + states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")] + + if area: + area_lower = area.lower() + states = [ + s for s in states + if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower() + or area_lower in (s.get("attributes", {}).get("area", "") or "").lower() + ] + + entities = [] + for s in states: + entities.append({ + "entity_id": s["entity_id"], + "state": s["state"], + "friendly_name": s.get("attributes", {}).get("friendly_name", ""), + }) + + return {"count": len(entities), "entities": entities} + + +async def _async_list_entities( + domain: Optional[str] = None, + area: Optional[str] = None, +) -> Dict[str, Any]: + """Fetch entity states from HA and optionally filter by domain/area.""" + import aiohttp + + url = f"{_HASS_URL}/api/states" + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=15)) as resp: + resp.raise_for_status() + states = await resp.json() + + return _filter_and_summarize(states, domain, area) + + +async def _async_get_state(entity_id: str) -> Dict[str, Any]: + """Fetch detailed state of a single entity.""" + import aiohttp + + url = f"{_HASS_URL}/api/states/{entity_id}" + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=10)) as resp: + resp.raise_for_status() + data = await resp.json() + + return { + "entity_id": data["entity_id"], + "state": data["state"], + "attributes": data.get("attributes", {}), + "last_changed": data.get("last_changed"), + "last_updated": data.get("last_updated"), + } + + +def _build_service_payload( + entity_id: Optional[str] = None, + data: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Build the JSON payload for a HA service call.""" + payload: Dict[str, Any] = {} + if entity_id: + payload["entity_id"] = entity_id + if data: + payload.update(data) + return payload + + +def _parse_service_response( + domain: str, + service: str, + result: Any, +) -> Dict[str, Any]: + """Parse HA service call response into a structured result.""" + affected = [] + if isinstance(result, list): + for s in result: + affected.append({ + "entity_id": s.get("entity_id", ""), + "state": s.get("state", ""), + }) + + return { + "success": True, + "service": f"{domain}.{service}", + "affected_entities": affected, + } + + +async def _async_call_service( + domain: str, + service: str, + entity_id: Optional[str] = None, + data: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Call a Home Assistant service.""" + import aiohttp + + url = f"{_HASS_URL}/api/services/{domain}/{service}" + payload = _build_service_payload(entity_id, data) + + async with aiohttp.ClientSession() as session: + async with session.post( + url, + headers=_get_headers(), + json=payload, + timeout=aiohttp.ClientTimeout(total=15), + ) as resp: + resp.raise_for_status() + result = await resp.json() + + return _parse_service_response(domain, service, result) + + +# --------------------------------------------------------------------------- +# Sync wrappers (handler signature: (args, **kw) -> str) +# --------------------------------------------------------------------------- + +def _run_async(coro): + """Run an async coroutine from a sync handler.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + # Already inside an event loop -- create a new thread + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(asyncio.run, coro) + return future.result(timeout=30) + else: + return asyncio.run(coro) + + +def _handle_list_entities(args: dict, **kw) -> str: + """Handler for ha_list_entities tool.""" + domain = args.get("domain") + area = args.get("area") + try: + result = _run_async(_async_list_entities(domain=domain, area=area)) + return json.dumps({"result": result}) + except Exception as e: + logger.error("ha_list_entities error: %s", e) + return json.dumps({"error": f"Failed to list entities: {e}"}) + + +def _handle_get_state(args: dict, **kw) -> str: + """Handler for ha_get_state tool.""" + entity_id = args.get("entity_id", "") + if not entity_id: + return json.dumps({"error": "Missing required parameter: entity_id"}) + try: + result = _run_async(_async_get_state(entity_id)) + return json.dumps({"result": result}) + except Exception as e: + logger.error("ha_get_state error: %s", e) + return json.dumps({"error": f"Failed to get state for {entity_id}: {e}"}) + + +def _handle_call_service(args: dict, **kw) -> str: + """Handler for ha_call_service tool.""" + domain = args.get("domain", "") + service = args.get("service", "") + if not domain or not service: + return json.dumps({"error": "Missing required parameters: domain and service"}) + + entity_id = args.get("entity_id") + data = args.get("data") + try: + result = _run_async(_async_call_service(domain, service, entity_id, data)) + return json.dumps({"result": result}) + except Exception as e: + logger.error("ha_call_service error: %s", e) + return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"}) + + +# --------------------------------------------------------------------------- +# Availability check +# --------------------------------------------------------------------------- + +def _check_ha_available() -> bool: + """Tool is only available when HASS_TOKEN is set.""" + return bool(os.getenv("HASS_TOKEN")) + + +# --------------------------------------------------------------------------- +# Tool schemas +# --------------------------------------------------------------------------- + +HA_LIST_ENTITIES_SCHEMA = { + "name": "ha_list_entities", + "description": ( + "List Home Assistant entities. Optionally filter by domain " + "(light, switch, climate, sensor, binary_sensor, cover, fan, etc.) " + "or by area name (living room, kitchen, bedroom, etc.)." + ), + "parameters": { + "type": "object", + "properties": { + "domain": { + "type": "string", + "description": ( + "Entity domain to filter by (e.g. 'light', 'switch', 'climate', " + "'sensor', 'binary_sensor', 'cover', 'fan', 'media_player'). " + "Omit to list all entities." + ), + }, + "area": { + "type": "string", + "description": ( + "Area/room name to filter by (e.g. 'living room', 'kitchen'). " + "Matches against entity friendly names. Omit to list all." + ), + }, + }, + "required": [], + }, +} + +HA_GET_STATE_SCHEMA = { + "name": "ha_get_state", + "description": ( + "Get the detailed state of a single Home Assistant entity, including all " + "attributes (brightness, color, temperature setpoint, sensor readings, etc.)." + ), + "parameters": { + "type": "object", + "properties": { + "entity_id": { + "type": "string", + "description": ( + "The entity ID to query (e.g. 'light.living_room', " + "'climate.thermostat', 'sensor.temperature')." + ), + }, + }, + "required": ["entity_id"], + }, +} + +HA_CALL_SERVICE_SCHEMA = { + "name": "ha_call_service", + "description": ( + "Call a Home Assistant service to control a device. Common examples: " + "turn_on/turn_off lights and switches, set_temperature for climate, " + "open_cover/close_cover for blinds, set_volume_level for media players." + ), + "parameters": { + "type": "object", + "properties": { + "domain": { + "type": "string", + "description": ( + "Service domain (e.g. 'light', 'switch', 'climate', " + "'cover', 'media_player', 'fan', 'scene', 'script')." + ), + }, + "service": { + "type": "string", + "description": ( + "Service name (e.g. 'turn_on', 'turn_off', 'toggle', " + "'set_temperature', 'set_hvac_mode', 'open_cover', " + "'close_cover', 'set_volume_level')." + ), + }, + "entity_id": { + "type": "string", + "description": ( + "Target entity ID (e.g. 'light.living_room'). " + "Some services (like scene.turn_on) may not need this." + ), + }, + "data": { + "type": "object", + "description": ( + "Additional service data. Examples: " + '{"brightness": 255, "color_name": "blue"} for lights, ' + '{"temperature": 22, "hvac_mode": "heat"} for climate, ' + '{"volume_level": 0.5} for media players.' + ), + }, + }, + "required": ["domain", "service"], + }, +} + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + +from tools.registry import registry + +registry.register( + name="ha_list_entities", + toolset="homeassistant", + schema=HA_LIST_ENTITIES_SCHEMA, + handler=_handle_list_entities, + check_fn=_check_ha_available, +) + +registry.register( + name="ha_get_state", + toolset="homeassistant", + schema=HA_GET_STATE_SCHEMA, + handler=_handle_get_state, + check_fn=_check_ha_available, +) + +registry.register( + name="ha_call_service", + toolset="homeassistant", + schema=HA_CALL_SERVICE_SCHEMA, + handler=_handle_call_service, + check_fn=_check_ha_available, +) diff --git a/toolsets.py b/toolsets.py index 6090068a..44b81449 100644 --- a/toolsets.py +++ b/toolsets.py @@ -62,6 +62,8 @@ _HERMES_CORE_TOOLS = [ "send_message", # Honcho user context (gated on honcho being active via check_fn) "query_user_context", + # Home Assistant smart home control (gated on HASS_TOKEN via check_fn) + "ha_list_entities", "ha_get_state", "ha_call_service", ] @@ -193,8 +195,14 @@ TOOLSETS = { "tools": ["query_user_context"], "includes": [] }, - - + + "homeassistant": { + "description": "Home Assistant smart home control and monitoring", + "tools": ["ha_list_entities", "ha_get_state", "ha_call_service"], + "includes": [] + }, + + # Scenario-specific toolsets "debugging": { @@ -247,10 +255,16 @@ TOOLSETS = { "includes": [] }, + "hermes-homeassistant": { + "description": "Home Assistant bot toolset - smart home event monitoring and control", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + "hermes-gateway": { "description": "Gateway toolset - union of all messaging platform tools", "tools": [], - "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack"] + "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-homeassistant"] } } diff --git a/uv.lock b/uv.lock index 54863389..5e3bd5f7 100644 --- a/uv.lock +++ b/uv.lock @@ -1034,6 +1034,9 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, ] +homeassistant = [ + { name = "aiohttp" }, +] honcho = [ { name = "honcho-ai" }, ] @@ -1060,6 +1063,7 @@ tts-premium = [ [package.metadata] requires-dist = [ + { name = "aiohttp", marker = "extra == 'homeassistant'", specifier = ">=3.9.0" }, { name = "aiohttp", marker = "extra == 'messaging'", specifier = ">=3.9.0" }, { name = "croniter", marker = "extra == 'cron'" }, { name = "discord-py", marker = "extra == 'messaging'", specifier = ">=2.0" }, @@ -1071,6 +1075,7 @@ requires-dist = [ { name = "hermes-agent", extras = ["cli"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["cron"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["dev"], marker = "extra == 'all'" }, + { name = "hermes-agent", extras = ["homeassistant"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["honcho"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["messaging"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["modal"], marker = "extra == 'all'" }, @@ -1103,7 +1108,7 @@ requires-dist = [ { name = "tenacity" }, { name = "typer" }, ] -provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "all"] +provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "homeassistant", "all"] [[package]] name = "hf-xet" From b32c642af3cfd8c2fee700e1c05fad10fca07a0e Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Sat, 28 Feb 2026 14:28:04 +0300 Subject: [PATCH 2/5] test: add HA integration tests with fake in-process server Fake HA server (aiohttp.web) simulates full API surface over real TCP: - WebSocket auth handshake + event push - REST endpoints (states, services, notifications) 14 integration tests verify end-to-end flows without mocks: - WS connect/auth/subscribe/event-forwarding/disconnect - REST list/get/call-service against fake server - send() notification delivery and auth failure - 401/500 error handling --- tests/fakes/__init__.py | 0 tests/fakes/fake_ha_server.py | 288 +++++++++++++++++++ tests/integration/test_ha_integration.py | 341 +++++++++++++++++++++++ 3 files changed, 629 insertions(+) create mode 100644 tests/fakes/__init__.py create mode 100644 tests/fakes/fake_ha_server.py create mode 100644 tests/integration/test_ha_integration.py diff --git a/tests/fakes/__init__.py b/tests/fakes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fakes/fake_ha_server.py b/tests/fakes/fake_ha_server.py new file mode 100644 index 00000000..1d51bf51 --- /dev/null +++ b/tests/fakes/fake_ha_server.py @@ -0,0 +1,288 @@ +"""Fake Home Assistant server for integration testing. + +Provides a real HTTP + WebSocket server (via aiohttp.web) that mimics the +Home Assistant API surface used by hermes-agent: + +- ``/api/websocket`` -- WebSocket auth handshake + event push +- ``/api/states`` -- GET all entity states +- ``/api/states/{entity_id}`` -- GET single entity state +- ``/api/services/{domain}/{service}`` -- POST service call +- ``/api/services/persistent_notification/create`` -- POST notification + +Usage:: + + async with FakeHAServer(token="test-token") as server: + url = server.url # e.g. "http://127.0.0.1:54321" + await server.push_event(event_data) + assert server.received_notifications # verify what arrived +""" + +import asyncio +import json +from typing import Any, Dict, List, Optional + +import aiohttp +from aiohttp import web +from aiohttp.test_utils import TestServer + + +# -- Sample entity data ------------------------------------------------------- + +ENTITY_STATES: List[Dict[str, Any]] = [ + { + "entity_id": "light.bedroom", + "state": "on", + "attributes": {"friendly_name": "Bedroom Light", "brightness": 200}, + "last_changed": "2025-01-15T10:30:00+00:00", + "last_updated": "2025-01-15T10:30:00+00:00", + }, + { + "entity_id": "light.kitchen", + "state": "off", + "attributes": {"friendly_name": "Kitchen Light"}, + "last_changed": "2025-01-15T09:00:00+00:00", + "last_updated": "2025-01-15T09:00:00+00:00", + }, + { + "entity_id": "sensor.temperature", + "state": "22.5", + "attributes": { + "friendly_name": "Kitchen Temperature", + "unit_of_measurement": "C", + }, + "last_changed": "2025-01-15T10:00:00+00:00", + "last_updated": "2025-01-15T10:00:00+00:00", + }, + { + "entity_id": "switch.fan", + "state": "on", + "attributes": {"friendly_name": "Living Room Fan"}, + "last_changed": "2025-01-15T08:00:00+00:00", + "last_updated": "2025-01-15T08:00:00+00:00", + }, + { + "entity_id": "climate.thermostat", + "state": "heat", + "attributes": { + "friendly_name": "Main Thermostat", + "current_temperature": 21, + "temperature": 23, + }, + "last_changed": "2025-01-15T07:00:00+00:00", + "last_updated": "2025-01-15T07:00:00+00:00", + }, +] + + +class FakeHAServer: + """In-process fake Home Assistant for integration tests. + + Parameters + ---------- + token : str + The expected Bearer token for authentication. + """ + + def __init__(self, token: str = "test-token-123"): + self.token = token + + # Observability -- tests inspect these after exercising the adapter. + self.received_service_calls: List[Dict[str, Any]] = [] + self.received_notifications: List[Dict[str, Any]] = [] + + # Control -- tests push events, server forwards them over WS. + self._event_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + + # Flag to simulate auth rejection. + self.reject_auth = False + + # Flag to simulate server errors. + self.force_500 = False + + # Internal bookkeeping. + self._app: Optional[web.Application] = None + self._server: Optional[TestServer] = None + self._ws_connections: List[web.WebSocketResponse] = [] + + # -- Public helpers -------------------------------------------------------- + + @property + def url(self) -> str: + """Base URL of the running server, e.g. ``http://127.0.0.1:12345``.""" + assert self._server is not None, "Server not started" + host = self._server.host + port = self._server.port + return f"http://{host}:{port}" + + async def push_event(self, event_data: Dict[str, Any]) -> None: + """Enqueue a state_changed event for delivery over WebSocket.""" + await self._event_queue.put(event_data) + + # -- Lifecycle ------------------------------------------------------------- + + async def start(self) -> None: + self._app = self._build_app() + self._server = TestServer(self._app) + await self._server.start_server() + + async def stop(self) -> None: + # Close any remaining WS connections. + for ws in self._ws_connections: + if not ws.closed: + await ws.close() + self._ws_connections.clear() + if self._server is not None: + await self._server.close() + + async def __aenter__(self) -> "FakeHAServer": + await self.start() + return self + + async def __aexit__(self, *exc) -> None: + await self.stop() + + # -- Application construction ---------------------------------------------- + + def _build_app(self) -> web.Application: + app = web.Application() + app.router.add_get("/api/websocket", self._handle_ws) + app.router.add_get("/api/states", self._handle_get_states) + app.router.add_get("/api/states/{entity_id}", self._handle_get_state) + # Notification endpoint must be registered before the generic service + # route so that it takes priority. + app.router.add_post( + "/api/services/persistent_notification/create", + self._handle_notification, + ) + app.router.add_post( + "/api/services/{domain}/{service}", + self._handle_call_service, + ) + return app + + # -- Auth helper ----------------------------------------------------------- + + def _check_rest_auth(self, request: web.Request) -> Optional[web.Response]: + """Return a 401 response if the Bearer token is wrong, else None.""" + auth = request.headers.get("Authorization", "") + if auth != f"Bearer {self.token}": + return web.Response(status=401, text="Unauthorized") + if self.force_500: + return web.Response(status=500, text="Internal Server Error") + return None + + # -- WebSocket handler ----------------------------------------------------- + + async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + self._ws_connections.append(ws) + + # Step 1: auth_required + await ws.send_json({"type": "auth_required", "ha_version": "2025.1.0"}) + + # Step 2: receive auth + msg = await ws.receive() + if msg.type != aiohttp.WSMsgType.TEXT: + await ws.close() + return ws + auth_msg = json.loads(msg.data) + + # Step 3: validate + if self.reject_auth or auth_msg.get("access_token") != self.token: + await ws.send_json({"type": "auth_invalid", "message": "Invalid token"}) + await ws.close() + return ws + + await ws.send_json({"type": "auth_ok", "ha_version": "2025.1.0"}) + + # Step 4: subscribe_events + msg = await ws.receive() + if msg.type != aiohttp.WSMsgType.TEXT: + await ws.close() + return ws + sub_msg = json.loads(msg.data) + sub_id = sub_msg.get("id", 1) + + # Step 5: ACK + await ws.send_json({ + "id": sub_id, + "type": "result", + "success": True, + "result": None, + }) + + # Step 6: push events from queue until closed + try: + while not ws.closed: + try: + event_data = await asyncio.wait_for( + self._event_queue.get(), timeout=0.1, + ) + await ws.send_json({ + "id": sub_id, + "type": "event", + "event": event_data, + }) + except asyncio.TimeoutError: + continue + except (ConnectionResetError, asyncio.CancelledError): + pass + + return ws + + # -- REST handlers --------------------------------------------------------- + + async def _handle_get_states(self, request: web.Request) -> web.Response: + err = self._check_rest_auth(request) + if err: + return err + return web.json_response(ENTITY_STATES) + + async def _handle_get_state(self, request: web.Request) -> web.Response: + err = self._check_rest_auth(request) + if err: + return err + entity_id = request.match_info["entity_id"] + for s in ENTITY_STATES: + if s["entity_id"] == entity_id: + return web.json_response(s) + return web.Response(status=404, text=f"Entity {entity_id} not found") + + async def _handle_notification(self, request: web.Request) -> web.Response: + err = self._check_rest_auth(request) + if err: + return err + body = await request.json() + self.received_notifications.append(body) + return web.json_response([]) + + async def _handle_call_service(self, request: web.Request) -> web.Response: + err = self._check_rest_auth(request) + if err: + return err + domain = request.match_info["domain"] + service = request.match_info["service"] + body = await request.json() + + self.received_service_calls.append({ + "domain": domain, + "service": service, + "data": body, + }) + + # Return affected entities (mimics real HA behaviour for light/switch). + affected = [] + entity_id = body.get("entity_id") + if entity_id: + new_state = "on" if service == "turn_on" else "off" + for s in ENTITY_STATES: + if s["entity_id"] == entity_id: + affected.append({ + "entity_id": entity_id, + "state": new_state, + "attributes": s.get("attributes", {}), + }) + break + + return web.json_response(affected) diff --git a/tests/integration/test_ha_integration.py b/tests/integration/test_ha_integration.py new file mode 100644 index 00000000..7f7329ba --- /dev/null +++ b/tests/integration/test_ha_integration.py @@ -0,0 +1,341 @@ +"""Integration tests for Home Assistant (tool + gateway). + +Spins up a real in-process fake HA server (HTTP + WebSocket) and exercises +the full adapter and tool handler paths over real TCP connections. +No mocks -- only real async I/O against a fake server. + +Run with: uv run pytest tests/integration/test_ha_integration.py -v +""" + +import asyncio + +import pytest + +pytestmark = pytest.mark.integration + +from unittest.mock import AsyncMock + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.homeassistant import HomeAssistantAdapter +from tests.fakes.fake_ha_server import FakeHAServer, ENTITY_STATES +from tools.homeassistant_tool import ( + _async_call_service, + _async_get_state, + _async_list_entities, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _adapter_for(server: FakeHAServer, **extra) -> HomeAssistantAdapter: + """Create an adapter pointed at the fake server.""" + config = PlatformConfig( + enabled=True, + token=server.token, + extra={"url": server.url, **extra}, + ) + return HomeAssistantAdapter(config) + + +# --------------------------------------------------------------------------- +# 1. Gateway -- WebSocket lifecycle +# --------------------------------------------------------------------------- + + +class TestGatewayWebSocket: + @pytest.mark.asyncio + async def test_connect_auth_subscribe(self): + """Full WS handshake succeeds: auth_required -> auth -> auth_ok -> subscribe -> ACK.""" + async with FakeHAServer() as server: + adapter = _adapter_for(server) + connected = await adapter.connect() + assert connected is True + assert adapter._running is True + assert adapter._ws is not None + assert not adapter._ws.closed + await adapter.disconnect() + + @pytest.mark.asyncio + async def test_connect_auth_rejected(self): + """connect() returns False when the server rejects auth.""" + async with FakeHAServer() as server: + server.reject_auth = True + adapter = _adapter_for(server) + connected = await adapter.connect() + assert connected is False + + @pytest.mark.asyncio + async def test_event_received_and_forwarded(self): + """Server pushes event -> adapter calls handle_message with correct MessageEvent.""" + async with FakeHAServer() as server: + adapter = _adapter_for(server) + adapter.handle_message = AsyncMock() + + await adapter.connect() + + # Push a state_changed event + await server.push_event({ + "data": { + "entity_id": "light.bedroom", + "old_state": {"state": "off", "attributes": {}}, + "new_state": { + "state": "on", + "attributes": {"friendly_name": "Bedroom Light"}, + }, + } + }) + + # Wait for the adapter to process it + for _ in range(50): + if adapter.handle_message.call_count > 0: + break + await asyncio.sleep(0.05) + + assert adapter.handle_message.call_count == 1 + msg_event = adapter.handle_message.call_args[0][0] + assert "Bedroom Light" in msg_event.text + assert "turned on" in msg_event.text + assert msg_event.source.platform == Platform.HOMEASSISTANT + + await adapter.disconnect() + + @pytest.mark.asyncio + async def test_event_filtering_ignores_unwatched(self): + """Events outside watch_domains are silently dropped.""" + async with FakeHAServer() as server: + adapter = _adapter_for(server, watch_domains=["climate"]) + adapter.handle_message = AsyncMock() + + await adapter.connect() + + # Push a light event (not in watch_domains) + await server.push_event({ + "data": { + "entity_id": "light.bedroom", + "old_state": {"state": "off", "attributes": {}}, + "new_state": { + "state": "on", + "attributes": {"friendly_name": "Bedroom Light"}, + }, + } + }) + + await asyncio.sleep(0.5) + assert adapter.handle_message.call_count == 0 + + await adapter.disconnect() + + @pytest.mark.asyncio + async def test_disconnect_closes_cleanly(self): + """disconnect() cancels listener and closes WebSocket.""" + async with FakeHAServer() as server: + adapter = _adapter_for(server) + await adapter.connect() + ws_ref = adapter._ws + + await adapter.disconnect() + + assert adapter._running is False + assert adapter._listen_task is None + assert adapter._ws is None + # The original WS reference should be closed + assert ws_ref.closed + + +# --------------------------------------------------------------------------- +# 2. REST tool handlers (real HTTP against fake server) +# --------------------------------------------------------------------------- + + +class TestToolRest: + """Call the async tool functions directly against the fake server. + + Note: we call ``_async_*`` instead of the sync ``_handle_*`` wrappers + because the sync wrappers use ``_run_async`` which blocks the event + loop, deadlocking with the in-process fake server. The async functions + are the real logic; the sync wrappers are trivial bridge code already + covered by unit tests. + """ + + @pytest.mark.asyncio + async def test_list_entities_returns_all(self, monkeypatch): + """_async_list_entities returns all entities from the fake server.""" + async with FakeHAServer() as server: + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_URL", server.url, + ) + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_TOKEN", server.token, + ) + + result = await _async_list_entities() + + assert result["count"] == len(ENTITY_STATES) + ids = {e["entity_id"] for e in result["entities"]} + assert "light.bedroom" in ids + assert "climate.thermostat" in ids + + @pytest.mark.asyncio + async def test_list_entities_domain_filter(self, monkeypatch): + """Domain filter is applied after fetching from server.""" + async with FakeHAServer() as server: + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_URL", server.url, + ) + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_TOKEN", server.token, + ) + + result = await _async_list_entities(domain="light") + + assert result["count"] == 2 + for e in result["entities"]: + assert e["entity_id"].startswith("light.") + + @pytest.mark.asyncio + async def test_get_state_single_entity(self, monkeypatch): + """_async_get_state returns full entity details.""" + async with FakeHAServer() as server: + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_URL", server.url, + ) + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_TOKEN", server.token, + ) + + result = await _async_get_state("light.bedroom") + + assert result["entity_id"] == "light.bedroom" + assert result["state"] == "on" + assert result["attributes"]["brightness"] == 200 + assert result["last_changed"] is not None + + @pytest.mark.asyncio + async def test_get_state_not_found(self, monkeypatch): + """Non-existent entity raises an aiohttp error (404).""" + import aiohttp as _aiohttp + + async with FakeHAServer() as server: + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_URL", server.url, + ) + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_TOKEN", server.token, + ) + + with pytest.raises(_aiohttp.ClientResponseError) as exc_info: + await _async_get_state("light.nonexistent") + assert exc_info.value.status == 404 + + @pytest.mark.asyncio + async def test_call_service_turn_on(self, monkeypatch): + """_async_call_service sends correct payload and server records it.""" + async with FakeHAServer() as server: + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_URL", server.url, + ) + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_TOKEN", server.token, + ) + + result = await _async_call_service( + domain="light", + service="turn_on", + entity_id="light.bedroom", + data={"brightness": 255}, + ) + + assert result["success"] is True + assert result["service"] == "light.turn_on" + assert len(result["affected_entities"]) == 1 + assert result["affected_entities"][0]["state"] == "on" + + # Verify fake server recorded the call + assert len(server.received_service_calls) == 1 + call = server.received_service_calls[0] + assert call["domain"] == "light" + assert call["service"] == "turn_on" + assert call["data"]["entity_id"] == "light.bedroom" + assert call["data"]["brightness"] == 255 + + +# --------------------------------------------------------------------------- +# 3. send() -- REST notification +# --------------------------------------------------------------------------- + + +class TestSendNotification: + @pytest.mark.asyncio + async def test_send_notification_delivered(self): + """Adapter send() delivers notification to fake server REST endpoint.""" + async with FakeHAServer() as server: + adapter = _adapter_for(server) + + result = await adapter.send("ha_events", "Test notification from agent") + + assert result.success is True + assert len(server.received_notifications) == 1 + notif = server.received_notifications[0] + assert notif["title"] == "Hermes Agent" + assert notif["message"] == "Test notification from agent" + + @pytest.mark.asyncio + async def test_send_auth_failure(self): + """send() returns failure when token is wrong.""" + async with FakeHAServer() as server: + config = PlatformConfig( + enabled=True, + token="wrong-token", + extra={"url": server.url}, + ) + adapter = HomeAssistantAdapter(config) + + result = await adapter.send("ha_events", "Should fail") + + assert result.success is False + assert "401" in result.error + + +# --------------------------------------------------------------------------- +# 4. Auth and error cases +# --------------------------------------------------------------------------- + + +class TestAuthAndErrors: + @pytest.mark.asyncio + async def test_rest_unauthorized(self, monkeypatch): + """Async function raises on 401 when token is wrong.""" + import aiohttp as _aiohttp + + async with FakeHAServer() as server: + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_URL", server.url, + ) + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_TOKEN", "bad-token", + ) + + with pytest.raises(_aiohttp.ClientResponseError) as exc_info: + await _async_list_entities() + assert exc_info.value.status == 401 + + @pytest.mark.asyncio + async def test_rest_server_error(self, monkeypatch): + """Async function raises on 500 response.""" + import aiohttp as _aiohttp + + async with FakeHAServer() as server: + server.force_500 = True + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_URL", server.url, + ) + monkeypatch.setattr( + "tools.homeassistant_tool._HASS_TOKEN", server.token, + ) + + with pytest.raises(_aiohttp.ClientResponseError) as exc_info: + await _async_list_entities() + assert exc_info.value.status == 500 From 2390728cc38b1236279820971439e74f4d88b8ff Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Sat, 28 Feb 2026 15:12:18 +0300 Subject: [PATCH 3/5] fix: resolve 4 bugs found in HA integration code review - Auto-authorize HA events in gateway (system-generated, not user messages) - Guard _read_events against None/closed WebSocket after failed reconnect - Use UUID for send() message_id instead of polluting WS sequence counter - entity_id parameter now takes precedence over data["entity_id"] --- gateway/platforms/homeassistant.py | 5 ++++- gateway/run.py | 6 ++++++ tests/tools/test_homeassistant_tool.py | 6 +++--- tools/homeassistant_tool.py | 5 +++-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/gateway/platforms/homeassistant.py b/gateway/platforms/homeassistant.py index 749cdf1e..08dfa099 100644 --- a/gateway/platforms/homeassistant.py +++ b/gateway/platforms/homeassistant.py @@ -17,6 +17,7 @@ import json import logging import os import time +import uuid from datetime import datetime from typing import Any, Dict, List, Optional, Set @@ -228,6 +229,8 @@ class HomeAssistantAdapter(BasePlatformAdapter): async def _read_events(self) -> None: """Read events from WebSocket until disconnected.""" + if self._ws is None or self._ws.closed: + return async for ws_msg in self._ws: if ws_msg.type == aiohttp.WSMsgType.TEXT: try: @@ -390,7 +393,7 @@ class HomeAssistantAdapter(BasePlatformAdapter): timeout=aiohttp.ClientTimeout(total=10), ) as resp: if resp.status < 300: - return SendResult(success=True, message_id=str(self._next_id())) + return SendResult(success=True, message_id=uuid.uuid4().hex[:12]) else: body = await resp.text() return SendResult(success=False, error=f"HTTP {resp.status}: {body}") diff --git a/gateway/run.py b/gateway/run.py index 76ed3666..198629ce 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -490,6 +490,12 @@ class GatewayRunner: 4. Global allow-all (GATEWAY_ALLOW_ALL_USERS=true) 5. Default: deny """ + # Home Assistant events are system-generated (state changes), not + # user-initiated messages. The HASS_TOKEN already authenticates the + # connection, so HA events are always authorized. + if source.platform == Platform.HOMEASSISTANT: + return True + user_id = source.user_id if not user_id: return False diff --git a/tests/tools/test_homeassistant_tool.py b/tests/tools/test_homeassistant_tool.py index 6235474e..b57df069 100644 --- a/tests/tools/test_homeassistant_tool.py +++ b/tests/tools/test_homeassistant_tool.py @@ -130,13 +130,13 @@ class TestBuildServicePayload: payload = _build_service_payload() assert payload == {} - def test_data_does_not_overwrite_entity_id(self): + def test_entity_id_param_takes_precedence_over_data(self): payload = _build_service_payload( entity_id="light.a", data={"entity_id": "light.b"}, ) - # data.update overwrites entity_id set earlier - assert payload["entity_id"] == "light.b" + # explicit entity_id parameter wins over data["entity_id"] + assert payload["entity_id"] == "light.a" # --------------------------------------------------------------------------- diff --git a/tools/homeassistant_tool.py b/tools/homeassistant_tool.py index 4a01382f..b351cfec 100644 --- a/tools/homeassistant_tool.py +++ b/tools/homeassistant_tool.py @@ -106,10 +106,11 @@ def _build_service_payload( ) -> Dict[str, Any]: """Build the JSON payload for a HA service call.""" payload: Dict[str, Any] = {} - if entity_id: - payload["entity_id"] = entity_id if data: payload.update(data) + # entity_id parameter takes precedence over data["entity_id"] + if entity_id: + payload["entity_id"] = entity_id return payload From dfd50ceccd8ff6b743bc5f23a2dff0d2ac5aa3b9 Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Sat, 28 Feb 2026 18:01:13 +0300 Subject: [PATCH 4/5] fix: preserve Gemini thought_signature in tool call messages Gemini 3 thinking models attach extra_content with thought_signature to function call responses. This must be echoed back on subsequent API calls or the server rejects with a 400 error. The assistant message builder was dropping this field, causing all Gemini 3 Flash/Pro tool-calling flows to fail after the first function call. --- run_agent.py | 18 ++++++++++++++---- tests/test_run_agent.py | 18 ++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/run_agent.py b/run_agent.py index 59a547f0..3a939d16 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1369,8 +1369,9 @@ class AIAgent: ] if assistant_message.tool_calls: - msg["tool_calls"] = [ - { + tc_list = [] + for tool_call in assistant_message.tool_calls: + tc_dict = { "id": tool_call.id, "type": tool_call.type, "function": { @@ -1378,8 +1379,17 @@ class AIAgent: "arguments": tool_call.function.arguments } } - for tool_call in assistant_message.tool_calls - ] + # Preserve extra_content (e.g. Gemini thought_signature) so it + # is sent back on subsequent API calls. Without this, Gemini 3 + # thinking models reject the request with a 400 error. + extra = getattr(tool_call, "extra_content", None) + if extra is not None: + # Convert Pydantic models to plain dicts for JSON safety + if hasattr(extra, "model_dump"): + extra = extra.model_dump() + tc_dict["extra_content"] = extra + tc_list.append(tc_dict) + msg["tool_calls"] = tc_list return msg diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index 2d370393..ad90bd27 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -546,6 +546,24 @@ class TestBuildAssistantMessage: result = agent._build_assistant_message(msg, "stop") assert result["content"] == "" + def test_tool_call_extra_content_preserved(self, agent): + """Gemini thinking models attach extra_content with thought_signature + to tool calls. This must be preserved so subsequent API calls include it.""" + tc = _mock_tool_call(name="get_weather", arguments='{"city":"NYC"}', call_id="c2") + tc.extra_content = {"google": {"thought_signature": "abc123"}} + msg = _mock_assistant_msg(content="", tool_calls=[tc]) + result = agent._build_assistant_message(msg, "tool_calls") + assert result["tool_calls"][0]["extra_content"] == { + "google": {"thought_signature": "abc123"} + } + + def test_tool_call_without_extra_content(self, agent): + """Standard tool calls (no thinking model) should not have extra_content.""" + tc = _mock_tool_call(name="web_search", arguments='{}', call_id="c3") + msg = _mock_assistant_msg(content="", tool_calls=[tc]) + result = agent._build_assistant_message(msg, "tool_calls") + assert "extra_content" not in result["tool_calls"][0] + class TestFormatToolsForSystemMessage: def test_no_tools_returns_empty_array(self, agent): From 25fb9aafcbf1530f13b4df2e52a817a6a43dfaa5 Mon Sep 17 00:00:00 2001 From: 0xbyt4 <35742124+0xbyt4@users.noreply.github.com> Date: Sun, 1 Mar 2026 11:53:50 +0300 Subject: [PATCH 5/5] fix: add service domain blocklist and entity_id validation to HA tools Block dangerous HA service domains (shell_command, command_line, python_script, pyscript, hassio, rest_command) that allow arbitrary code execution or SSRF. Add regex validation for entity_id to prevent path traversal attacks. 17 new tests covering both security features. --- tests/tools/test_homeassistant_tool.py | 92 ++++++++++++++++++++++++++ tools/homeassistant_tool.py | 27 ++++++++ 2 files changed, 119 insertions(+) diff --git a/tests/tools/test_homeassistant_tool.py b/tests/tools/test_homeassistant_tool.py index b57df069..b136b565 100644 --- a/tests/tools/test_homeassistant_tool.py +++ b/tests/tools/test_homeassistant_tool.py @@ -16,6 +16,8 @@ from tools.homeassistant_tool import ( _get_headers, _handle_get_state, _handle_call_service, + _BLOCKED_DOMAINS, + _ENTITY_ID_RE, ) @@ -211,6 +213,96 @@ class TestHandlerValidation: assert "error" in result +# --------------------------------------------------------------------------- +# Security: domain blocklist +# --------------------------------------------------------------------------- + + +class TestDomainBlocklist: + """Verify dangerous HA service domains are blocked.""" + + @pytest.mark.parametrize("domain", sorted(_BLOCKED_DOMAINS)) + def test_blocked_domain_rejected(self, domain): + result = json.loads(_handle_call_service({ + "domain": domain, "service": "any_service" + })) + assert "error" in result + assert "blocked" in result["error"].lower() + + def test_safe_domain_not_blocked(self): + """Safe domains like 'light' should not be blocked (will fail on network, not blocklist).""" + # This will try to make a real HTTP call and fail, but the important thing + # is it does NOT return a "blocked" error + result = json.loads(_handle_call_service({ + "domain": "light", "service": "turn_on", "entity_id": "light.test" + })) + # Should fail with a network/connection error, not a "blocked" error + if "error" in result: + assert "blocked" not in result["error"].lower() + + def test_blocked_domains_include_shell_command(self): + assert "shell_command" in _BLOCKED_DOMAINS + + def test_blocked_domains_include_hassio(self): + assert "hassio" in _BLOCKED_DOMAINS + + def test_blocked_domains_include_rest_command(self): + assert "rest_command" in _BLOCKED_DOMAINS + + +# --------------------------------------------------------------------------- +# Security: entity_id validation +# --------------------------------------------------------------------------- + + +class TestEntityIdValidation: + """Verify entity_id format validation prevents path traversal.""" + + def test_valid_entity_id_accepted(self): + assert _ENTITY_ID_RE.match("light.bedroom") + assert _ENTITY_ID_RE.match("sensor.temperature_1") + assert _ENTITY_ID_RE.match("binary_sensor.motion") + assert _ENTITY_ID_RE.match("climate.main_thermostat") + + def test_path_traversal_rejected(self): + assert _ENTITY_ID_RE.match("../../config") is None + assert _ENTITY_ID_RE.match("light/../../../etc/passwd") is None + assert _ENTITY_ID_RE.match("../api/config") is None + + def test_special_chars_rejected(self): + assert _ENTITY_ID_RE.match("light.bed room") is None # space + assert _ENTITY_ID_RE.match("light.bed;rm -rf") is None # semicolon + assert _ENTITY_ID_RE.match("light.bed/room") is None # slash + assert _ENTITY_ID_RE.match("LIGHT.BEDROOM") is None # uppercase + + def test_missing_domain_rejected(self): + assert _ENTITY_ID_RE.match(".bedroom") is None + assert _ENTITY_ID_RE.match("bedroom") is None + + def test_get_state_rejects_invalid_entity_id(self): + result = json.loads(_handle_get_state({"entity_id": "../../config"})) + assert "error" in result + assert "Invalid entity_id" in result["error"] + + def test_call_service_rejects_invalid_entity_id(self): + result = json.loads(_handle_call_service({ + "domain": "light", + "service": "turn_on", + "entity_id": "../../../etc/passwd", + })) + assert "error" in result + assert "Invalid entity_id" in result["error"] + + def test_call_service_allows_no_entity_id(self): + """Some services (like scene.turn_on) don't need entity_id.""" + # Will fail on network, but should NOT fail on entity_id validation + result = json.loads(_handle_call_service({ + "domain": "scene", "service": "turn_on" + })) + if "error" in result: + assert "Invalid entity_id" not in result["error"] + + # --------------------------------------------------------------------------- # Availability check # --------------------------------------------------------------------------- diff --git a/tools/homeassistant_tool.py b/tools/homeassistant_tool.py index b351cfec..17729610 100644 --- a/tools/homeassistant_tool.py +++ b/tools/homeassistant_tool.py @@ -13,6 +13,7 @@ import asyncio import json import logging import os +import re from typing import Any, Dict, Optional logger = logging.getLogger(__name__) @@ -24,6 +25,21 @@ logger = logging.getLogger(__name__) _HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/") _HASS_TOKEN: str = os.getenv("HASS_TOKEN", "") +# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1") +_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$") + +# Service domains blocked for security -- these allow arbitrary code/command +# execution on the HA host or enable SSRF attacks on the local network. +# HA provides zero service-level access control; all safety must be in our layer. +_BLOCKED_DOMAINS = frozenset({ + "shell_command", # arbitrary shell commands as root in HA container + "command_line", # sensors/switches that execute shell commands + "python_script", # sandboxed but can escalate via hass.services.call() + "pyscript", # scripting integration with broader access + "hassio", # addon control, host shutdown/reboot, stdin to containers + "rest_command", # HTTP requests from HA server (SSRF vector) +}) + def _get_headers() -> Dict[str, str]: """Return authorization headers for HA REST API.""" @@ -198,6 +214,8 @@ def _handle_get_state(args: dict, **kw) -> str: entity_id = args.get("entity_id", "") if not entity_id: return json.dumps({"error": "Missing required parameter: entity_id"}) + if not _ENTITY_ID_RE.match(entity_id): + return json.dumps({"error": f"Invalid entity_id format: {entity_id}"}) try: result = _run_async(_async_get_state(entity_id)) return json.dumps({"result": result}) @@ -213,7 +231,16 @@ def _handle_call_service(args: dict, **kw) -> str: if not domain or not service: return json.dumps({"error": "Missing required parameters: domain and service"}) + if domain in _BLOCKED_DOMAINS: + return json.dumps({ + "error": f"Service domain '{domain}' is blocked for security. " + f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}" + }) + entity_id = args.get("entity_id") + if entity_id and not _ENTITY_ID_RE.match(entity_id): + return json.dumps({"error": f"Invalid entity_id format: {entity_id}"}) + data = args.get("data") try: result = _run_async(_async_call_service(domain, service, entity_id, data))