# surfaces/tests/platform/test_agent_session.py
# (source-viewer metadata: 93 lines, 3.2 KiB, Python)

import pytest
from aiohttp import web
from sdk.interface import MessageChunk, MessageResponse
from sdk.agent_session import AgentSessionClient, AgentSessionConfig, build_thread_key
def test_build_thread_key_uses_platform_user_and_chat_id():
    """The key length-prefixes each of platform, user id and chat id.

    ``"matrix"`` is 6 chars, ``"@alice:example.org"`` is 18 and ``"C1"`` is 2,
    giving ``6:matrix`` + ``18:@alice:example.org`` + ``2:C1``.
    """
    expected = "6:matrix18:@alice:example.org2:C1"
    actual = build_thread_key("matrix", "@alice:example.org", "C1")
    assert actual == expected
def test_build_thread_key_does_not_collide_when_user_id_contains_colons():
    """Colons inside the user id must not let two distinct triples share a key.

    A naive ``":".join(...)`` would map both triples below to the same string,
    ``"matrix:@alice:example.org:C1"``.
    """
    distinct_keys = {
        build_thread_key("matrix", "@alice:example.org", "C1"),
        build_thread_key("matrix", "@alice", "example.org:C1"),
    }
    assert len(distinct_keys) == 2
@pytest.mark.asyncio
async def test_stream_message_yields_text_chunks_and_end(aiohttp_server):
    """``stream_message`` yields one chunk per text event plus a final chunk.

    A fake agent websocket sends a STATUS frame, waits for the user message,
    then streams two text chunks and an END event carrying ``tokens_used``.
    The client is expected to forward ``thread_key`` verbatim as the
    ``thread_id`` query parameter and as each chunk's ``message_id``.
    """
    thread_key = "matrix:@alice:example.org:C1"
    # Record what the server saw and assert it in the test body. An
    # AssertionError raised inside an aiohttp handler is handled (logged) by
    # the server instead of propagating to pytest, so the test would only fail
    # indirectly -- e.g. with missing chunks -- and hide the real mismatch.
    received = {}

    async def handler(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        received["thread_id"] = request.query.get("thread_id")
        await ws.send_json({"type": "STATUS"})
        received["message"] = await ws.receive_json()
        await ws.send_json({"type": "AGENT_EVENT_TEXT_CHUNK", "text": "hel"})
        await ws.send_json({"type": "AGENT_EVENT_TEXT_CHUNK", "text": "lo"})
        await ws.send_json({"type": "AGENT_EVENT_END", "tokens_used": 7})
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_get("/agent_ws/", handler)
    server = await aiohttp_server(app)
    client = AgentSessionClient(
        AgentSessionConfig(base_ws_url=str(server.make_url("/agent_ws/")))
    )

    chunks = [
        chunk
        async for chunk in client.stream_message(thread_key=thread_key, text="hello")
    ]

    assert received.get("thread_id") == thread_key
    assert received.get("message") == {"type": "USER_MESSAGE", "text": "hello"}
    assert chunks == [
        MessageChunk(message_id=thread_key, delta="hel", finished=False, tokens_used=0),
        MessageChunk(message_id=thread_key, delta="lo", finished=False, tokens_used=0),
        MessageChunk(message_id=thread_key, delta="", finished=True, tokens_used=7),
    ]
@pytest.mark.asyncio
async def test_send_message_collects_streamed_chunks_and_tokens(aiohttp_server):
    """``send_message`` concatenates streamed text and reports the END tokens.

    The fake agent streams ``"hello "`` and ``"world"`` followed by an END
    event with ``tokens_used=11``; the client should return a single finished
    ``MessageResponse`` whose ``message_id`` is the ``thread_key`` it was given.
    """
    thread_key = "matrix:@alice:example.org:C1"
    # Capture server-side observations and assert them in the test body: an
    # assertion failing inside the aiohttp handler would be handled by the
    # server rather than reported by pytest with the actual mismatch.
    received = {}

    async def handler(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        received["thread_id"] = request.query.get("thread_id")
        await ws.send_json({"type": "STATUS"})
        received["message"] = await ws.receive_json()
        await ws.send_json({"type": "AGENT_EVENT_TEXT_CHUNK", "text": "hello "})
        await ws.send_json({"type": "AGENT_EVENT_TEXT_CHUNK", "text": "world"})
        await ws.send_json({"type": "AGENT_EVENT_END", "tokens_used": 11})
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_get("/agent_ws/", handler)
    server = await aiohttp_server(app)
    client = AgentSessionClient(
        AgentSessionConfig(base_ws_url=str(server.make_url("/agent_ws/")))
    )

    result = await client.send_message(thread_key=thread_key, text="hello world")

    assert received.get("thread_id") == thread_key
    assert received.get("message") == {"type": "USER_MESSAGE", "text": "hello world"}
    assert result == MessageResponse(
        message_id=thread_key,
        response="hello world",
        tokens_used=11,
        finished=True,
    )