from __future__ import annotations from dataclasses import dataclass from typing import AsyncIterator from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit from sdk.interface import MessageChunk, MessageResponse, PlatformError def build_thread_key(platform: str, user_id: str, chat_id: str) -> str: return f"{len(platform)}:{platform}{len(user_id)}:{user_id}{len(chat_id)}:{chat_id}" @dataclass(frozen=True, slots=True) class AgentSessionConfig: base_ws_url: str timeout_seconds: float = 30.0 class AgentSessionClient: def __init__(self, config: AgentSessionConfig) -> None: self._config = config async def send_message(self, *, thread_key: str, text: str) -> MessageResponse: response_parts: list[str] = [] tokens_used = 0 async for chunk in self.stream_message(thread_key=thread_key, text=text): if chunk.delta: response_parts.append(chunk.delta) if chunk.finished: tokens_used = chunk.tokens_used return MessageResponse( message_id=thread_key, response="".join(response_parts), tokens_used=tokens_used, finished=True, ) async def stream_message(self, *, thread_key: str, text: str) -> AsyncIterator[MessageChunk]: import aiohttp async with aiohttp.ClientSession() as session: async with session.ws_connect( self._ws_url(thread_key), heartbeat=30, ) as ws: status = await ws.receive_json(timeout=self._config.timeout_seconds) if status.get("type") != "STATUS": raise PlatformError("Agent did not send STATUS", code="AGENT_PROTOCOL_ERROR") await ws.send_json({"type": "USER_MESSAGE", "text": text}) while True: payload = await ws.receive_json(timeout=self._config.timeout_seconds) msg_type = payload.get("type") if msg_type == "AGENT_EVENT_TEXT_CHUNK": yield MessageChunk( message_id=thread_key, delta=payload["text"], finished=False, ) elif msg_type == "AGENT_EVENT_END": yield MessageChunk( message_id=thread_key, delta="", finished=True, tokens_used=payload.get("tokens_used", 0), ) return elif msg_type == "ERROR": raise PlatformError( payload.get("details", "Agent error"), code=payload.get("code", "AGENT_ERROR"), ) elif msg_type == "GRACEFUL_DISCONNECT": raise PlatformError( "Agent disconnected gracefully", code="GRACEFUL_DISCONNECT", ) else: raise PlatformError( f"Unexpected agent message: {payload}", code="AGENT_PROTOCOL_ERROR", ) def _ws_url(self, thread_key: str) -> str: parts = urlsplit(self._config.base_ws_url) query = dict(parse_qsl(parts.query, keep_blank_values=True)) query["thread_id"] = thread_key return urlunsplit(parts._replace(query=urlencode(query)))