Serialize Matrix chat sends
This commit is contained in:
parent
4533118b68
commit
17d580096b
4 changed files with 281 additions and 79 deletions
|
|
@ -21,17 +21,12 @@ from adapter.matrix.converter import from_room_event
|
||||||
from adapter.matrix.handlers import register_matrix_handlers
|
from adapter.matrix.handlers import register_matrix_handlers
|
||||||
from adapter.matrix.handlers.context_commands import (
|
from adapter.matrix.handlers.context_commands import (
|
||||||
LOAD_PROMPT,
|
LOAD_PROMPT,
|
||||||
SAVE_PROMPT,
|
|
||||||
_call_reset_endpoint,
|
|
||||||
_sanitize_session_name,
|
|
||||||
)
|
)
|
||||||
from adapter.matrix.handlers.auth import handle_invite
|
from adapter.matrix.handlers.auth import handle_invite, provision_workspace_chat
|
||||||
from adapter.matrix.room_router import resolve_chat_id
|
from adapter.matrix.room_router import resolve_chat_id
|
||||||
from adapter.matrix.store import (
|
from adapter.matrix.store import (
|
||||||
clear_load_pending,
|
clear_load_pending,
|
||||||
clear_reset_pending,
|
|
||||||
get_load_pending,
|
get_load_pending,
|
||||||
get_reset_pending,
|
|
||||||
get_room_meta,
|
get_room_meta,
|
||||||
set_pending_confirm,
|
set_pending_confirm,
|
||||||
)
|
)
|
||||||
|
|
@ -153,11 +148,12 @@ class MatrixBot:
|
||||||
await self._send_all(room.room_id, outgoing)
|
await self._send_all(room.room_id, outgoing)
|
||||||
return
|
return
|
||||||
|
|
||||||
reset_pending = await get_reset_pending(self.runtime.store, sender, room.room_id)
|
room_meta = await get_room_meta(self.runtime.store, room.room_id)
|
||||||
if reset_pending is not None and (body in {"!yes", "!no"} or body.startswith("!save ")):
|
if room_meta is None:
|
||||||
outgoing = await self._handle_reset_selection(sender, room.room_id, body)
|
outgoing = await self._bootstrap_unregistered_room(room, sender)
|
||||||
await self._send_all(room.room_id, outgoing)
|
if outgoing:
|
||||||
return
|
await self._send_all(room.room_id, outgoing)
|
||||||
|
return
|
||||||
|
|
||||||
chat_id = await resolve_chat_id(self.runtime.store, room.room_id, sender)
|
chat_id = await resolve_chat_id(self.runtime.store, room.room_id, sender)
|
||||||
incoming = from_room_event(event, room_id=room.room_id, chat_id=chat_id)
|
incoming = from_room_event(event, room_id=room.room_id, chat_id=chat_id)
|
||||||
|
|
@ -181,6 +177,57 @@ class MatrixBot:
|
||||||
]
|
]
|
||||||
await self._send_all(room.room_id, outgoing)
|
await self._send_all(room.room_id, outgoing)
|
||||||
|
|
||||||
|
async def _bootstrap_unregistered_room(
|
||||||
|
self,
|
||||||
|
room: MatrixRoom,
|
||||||
|
sender: str,
|
||||||
|
) -> list[OutgoingEvent] | None:
|
||||||
|
if not hasattr(self.client, "room_create") or not hasattr(self.client, "room_put_state"):
|
||||||
|
return None
|
||||||
|
display_name = getattr(room, "display_name", None) or sender
|
||||||
|
try:
|
||||||
|
created = await provision_workspace_chat(
|
||||||
|
self.client,
|
||||||
|
sender,
|
||||||
|
display_name,
|
||||||
|
self.runtime.platform,
|
||||||
|
self.runtime.store,
|
||||||
|
self.runtime.auth_mgr,
|
||||||
|
self.runtime.chat_mgr,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"matrix_unregistered_room_bootstrap_failed",
|
||||||
|
room_id=room.room_id,
|
||||||
|
sender=sender,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=room.room_id,
|
||||||
|
text="Не удалось подготовить рабочий чат. Попробуйте ещё раз позже.",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
welcome = (
|
||||||
|
f"Привет, {created['user'].display_name or sender}! Пиши — я здесь.\n\n"
|
||||||
|
"Команды: !new · !chats · !rename · !archive · !context · !save · !load · !help"
|
||||||
|
)
|
||||||
|
await self.client.room_send(
|
||||||
|
created["chat_room_id"],
|
||||||
|
"m.room.message",
|
||||||
|
{"msgtype": "m.text", "body": welcome},
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
OutgoingMessage(
|
||||||
|
chat_id=room.room_id,
|
||||||
|
text=(
|
||||||
|
f"Создал рабочий чат {created['room_name']} ({created['chat_id']}) "
|
||||||
|
"и добавил его в пространство Lambda. Открой приглашённую комнату для продолжения."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
async def _handle_load_selection(
|
async def _handle_load_selection(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|
@ -217,45 +264,7 @@ class MatrixBot:
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("load_agent_call_failed", error=str(exc))
|
logger.warning("load_agent_call_failed", error=str(exc))
|
||||||
return [OutgoingMessage(chat_id=room_id, text=f"Ошибка при загрузке: {exc}")]
|
return [OutgoingMessage(chat_id=room_id, text=f"Ошибка при загрузке: {exc}")]
|
||||||
return [OutgoingMessage(chat_id=room_id, text=f"Загрузка: {name}")]
|
return [OutgoingMessage(chat_id=room_id, text=f"Запрос на загрузку отправлен агенту: {name}")]
|
||||||
|
|
||||||
async def _handle_reset_selection(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
room_id: str,
|
|
||||||
text: str,
|
|
||||||
) -> list[OutgoingEvent]:
|
|
||||||
agent_base_url = os.environ.get("AGENT_BASE_URL", "http://127.0.0.1:8000")
|
|
||||||
prototype_state = getattr(self.runtime.platform, "_prototype_state", None)
|
|
||||||
await clear_reset_pending(self.runtime.store, user_id, room_id)
|
|
||||||
|
|
||||||
if text == "!no":
|
|
||||||
return [OutgoingMessage(chat_id=room_id, text="Отменено.")]
|
|
||||||
|
|
||||||
if text.startswith("!save "):
|
|
||||||
name = _sanitize_session_name(text[len("!save ") :].strip())
|
|
||||||
if name is None:
|
|
||||||
return [
|
|
||||||
OutgoingMessage(
|
|
||||||
chat_id=room_id,
|
|
||||||
text="Имя сохранения может содержать только буквы, цифры, _ и -.",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
try:
|
|
||||||
await self.runtime.platform.send_message(
|
|
||||||
user_id,
|
|
||||||
room_id,
|
|
||||||
SAVE_PROMPT.format(name=name),
|
|
||||||
)
|
|
||||||
if prototype_state is not None:
|
|
||||||
await prototype_state.add_saved_session(user_id, name)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("save_before_reset_failed", error=str(exc))
|
|
||||||
return [OutgoingMessage(chat_id=room_id, text=f"Ошибка при сохранении: {exc}")]
|
|
||||||
|
|
||||||
if prototype_state is not None:
|
|
||||||
await prototype_state.clear_current_session(user_id)
|
|
||||||
return await _call_reset_endpoint(agent_base_url, room_id)
|
|
||||||
|
|
||||||
async def on_member(self, room: MatrixRoom, event: RoomMemberEvent) -> None:
|
async def on_member(self, room: MatrixRoom, event: RoomMemberEvent) -> None:
|
||||||
if getattr(event, "sender", None) == self.client.user_id:
|
if getattr(event, "sender", None) == self.client.user_id:
|
||||||
|
|
@ -373,12 +382,11 @@ async def main() -> None:
|
||||||
request_timeout=client_config.request_timeout,
|
request_timeout=client_config.request_timeout,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
if isinstance(runtime.platform, RealPlatformClient):
|
|
||||||
await runtime.platform.agent_api.connect()
|
|
||||||
await client.sync_forever(timeout=30000, since=since_token)
|
await client.sync_forever(timeout=30000, since=since_token)
|
||||||
finally:
|
finally:
|
||||||
if isinstance(runtime.platform, RealPlatformClient):
|
close = getattr(runtime.platform, "close", None)
|
||||||
await runtime.platform.agent_api.close()
|
if callable(close):
|
||||||
|
await close()
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
36
sdk/real.py
36
sdk/real.py
|
|
@ -20,6 +20,7 @@ class RealPlatformClient(PlatformClient):
|
||||||
self._platform = platform
|
self._platform = platform
|
||||||
self._chat_apis: dict[str, AgentApiWrapper] = {}
|
self._chat_apis: dict[str, AgentApiWrapper] = {}
|
||||||
self._chat_api_lock = asyncio.Lock()
|
self._chat_api_lock = asyncio.Lock()
|
||||||
|
self._chat_send_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agent_api(self) -> AgentApiWrapper:
|
def agent_api(self) -> AgentApiWrapper:
|
||||||
|
|
@ -41,6 +42,14 @@ class RealPlatformClient(PlatformClient):
|
||||||
self._chat_apis[chat_key] = chat_api
|
self._chat_apis[chat_key] = chat_api
|
||||||
return chat_api
|
return chat_api
|
||||||
|
|
||||||
|
def _get_chat_send_lock(self, chat_id: str) -> asyncio.Lock:
|
||||||
|
chat_key = str(chat_id)
|
||||||
|
lock = self._chat_send_locks.get(chat_key)
|
||||||
|
if lock is None:
|
||||||
|
lock = asyncio.Lock()
|
||||||
|
self._chat_send_locks[chat_key] = lock
|
||||||
|
return lock
|
||||||
|
|
||||||
async def get_or_create_user(
|
async def get_or_create_user(
|
||||||
self,
|
self,
|
||||||
external_id: str,
|
external_id: str,
|
||||||
|
|
@ -85,21 +94,23 @@ class RealPlatformClient(PlatformClient):
|
||||||
text: str,
|
text: str,
|
||||||
attachments: list[Attachment] | None = None,
|
attachments: list[Attachment] | None = None,
|
||||||
) -> AsyncIterator[MessageChunk]:
|
) -> AsyncIterator[MessageChunk]:
|
||||||
chat_api = await self._get_chat_api(chat_id)
|
lock = self._get_chat_send_lock(chat_id)
|
||||||
if hasattr(chat_api, "last_tokens_used"):
|
async with lock:
|
||||||
chat_api.last_tokens_used = 0
|
chat_api = await self._get_chat_api(chat_id)
|
||||||
async for event in chat_api.send_message(text):
|
if hasattr(chat_api, "last_tokens_used"):
|
||||||
|
chat_api.last_tokens_used = 0
|
||||||
|
async for event in chat_api.send_message(text):
|
||||||
|
yield MessageChunk(
|
||||||
|
message_id=user_id,
|
||||||
|
delta=event.text,
|
||||||
|
finished=False,
|
||||||
|
)
|
||||||
yield MessageChunk(
|
yield MessageChunk(
|
||||||
message_id=user_id,
|
message_id=user_id,
|
||||||
delta=event.text,
|
delta="",
|
||||||
finished=False,
|
finished=True,
|
||||||
|
tokens_used=getattr(chat_api, "last_tokens_used", 0),
|
||||||
)
|
)
|
||||||
yield MessageChunk(
|
|
||||||
message_id=user_id,
|
|
||||||
delta="",
|
|
||||||
finished=True,
|
|
||||||
tokens_used=getattr(chat_api, "last_tokens_used", 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_settings(self, user_id: str) -> UserSettings:
|
async def get_settings(self, user_id: str) -> UserSettings:
|
||||||
return await self._prototype_state.get_settings(user_id)
|
return await self._prototype_state.get_settings(user_id)
|
||||||
|
|
@ -113,6 +124,7 @@ class RealPlatformClient(PlatformClient):
|
||||||
if callable(close):
|
if callable(close):
|
||||||
await close()
|
await close()
|
||||||
self._chat_apis.clear()
|
self._chat_apis.clear()
|
||||||
|
self._chat_send_locks.clear()
|
||||||
if not callable(getattr(self._agent_api, "for_chat", None)):
|
if not callable(getattr(self._agent_api, "for_chat", None)):
|
||||||
close = getattr(self._agent_api, "close", None)
|
close = getattr(self._agent_api, "close", None)
|
||||||
if callable(close):
|
if callable(close):
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ async def test_matrix_dispatcher_registers_custom_handlers():
|
||||||
user_id="u1", platform="matrix", chat_id=current_chat_id, command="settings_skills"
|
user_id="u1", platform="matrix", chat_id=current_chat_id, command="settings_skills"
|
||||||
)
|
)
|
||||||
result = await runtime.dispatcher.dispatch(skills)
|
result = await runtime.dispatcher.dispatch(skills)
|
||||||
assert any(isinstance(r, OutgoingMessage) and "!skill on/off" in r.text for r in result)
|
assert any(isinstance(r, OutgoingMessage) and "mvp" in r.text.lower() for r in result)
|
||||||
|
|
||||||
toggle = IncomingCallback(
|
toggle = IncomingCallback(
|
||||||
user_id="u1",
|
user_id="u1",
|
||||||
|
|
@ -54,7 +54,7 @@ async def test_matrix_dispatcher_registers_custom_handlers():
|
||||||
payload={"skill_index": 2},
|
payload={"skill_index": 2},
|
||||||
)
|
)
|
||||||
result = await runtime.dispatcher.dispatch(toggle)
|
result = await runtime.dispatcher.dispatch(toggle)
|
||||||
assert any(isinstance(r, OutgoingMessage) and "fetch-url" in r.text for r in result)
|
assert any(isinstance(r, OutgoingMessage) and "mvp" in r.text.lower() for r in result)
|
||||||
|
|
||||||
|
|
||||||
async def test_new_chat_creates_real_matrix_room_when_client_available():
|
async def test_new_chat_creates_real_matrix_room_when_client_available():
|
||||||
|
|
@ -226,7 +226,75 @@ async def test_bot_degrades_platform_errors_to_user_reply():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_mat11_settings_returns_dashboard():
|
async def test_unregistered_room_bootstraps_space_and_chat_on_first_message():
|
||||||
|
runtime = build_runtime(platform=MockPlatformClient())
|
||||||
|
await set_user_meta(runtime.store, "@alice:example.org", {"next_chat_index": 1})
|
||||||
|
space_resp = SimpleNamespace(room_id="!space:example.org")
|
||||||
|
chat_resp = SimpleNamespace(room_id="!chat1:example.org")
|
||||||
|
client = SimpleNamespace(
|
||||||
|
user_id="@bot:example.org",
|
||||||
|
room_create=AsyncMock(side_effect=[space_resp, chat_resp]),
|
||||||
|
room_put_state=AsyncMock(),
|
||||||
|
room_send=AsyncMock(),
|
||||||
|
)
|
||||||
|
bot = MatrixBot(client, runtime)
|
||||||
|
room = SimpleNamespace(room_id="!entry:example.org", display_name="Entry")
|
||||||
|
event = SimpleNamespace(sender="@alice:example.org", body="hello")
|
||||||
|
|
||||||
|
await bot.on_room_message(room, event)
|
||||||
|
|
||||||
|
assert client.room_create.await_count == 2
|
||||||
|
first_call = client.room_create.call_args_list[0]
|
||||||
|
second_call = client.room_create.call_args_list[1]
|
||||||
|
assert first_call.kwargs.get("space") is True
|
||||||
|
assert first_call.kwargs.get("invite") == ["@alice:example.org"]
|
||||||
|
assert second_call.kwargs.get("name") == "Чат 1"
|
||||||
|
assert second_call.kwargs.get("invite") == ["@alice:example.org"]
|
||||||
|
client.room_put_state.assert_awaited_once()
|
||||||
|
room_meta = await get_room_meta(runtime.store, "!chat1:example.org")
|
||||||
|
assert room_meta is not None
|
||||||
|
assert room_meta["chat_id"] == "C1"
|
||||||
|
user_meta = await get_user_meta(runtime.store, "@alice:example.org")
|
||||||
|
assert user_meta is not None
|
||||||
|
assert user_meta["space_id"] == "!space:example.org"
|
||||||
|
room_send_calls = client.room_send.await_args_list
|
||||||
|
assert any(call.args[0] == "!chat1:example.org" for call in room_send_calls)
|
||||||
|
assert any(call.args[0] == "!entry:example.org" for call in room_send_calls)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_unregistered_room_creates_new_chat_in_existing_space():
|
||||||
|
runtime = build_runtime(platform=MockPlatformClient())
|
||||||
|
await set_user_meta(
|
||||||
|
runtime.store,
|
||||||
|
"@alice:example.org",
|
||||||
|
{"space_id": "!space:example.org", "next_chat_index": 4},
|
||||||
|
)
|
||||||
|
chat_resp = SimpleNamespace(room_id="!chat4:example.org")
|
||||||
|
client = SimpleNamespace(
|
||||||
|
user_id="@bot:example.org",
|
||||||
|
room_create=AsyncMock(return_value=chat_resp),
|
||||||
|
room_put_state=AsyncMock(),
|
||||||
|
room_send=AsyncMock(),
|
||||||
|
)
|
||||||
|
bot = MatrixBot(client, runtime)
|
||||||
|
room = SimpleNamespace(room_id="!entry:example.org", display_name="Entry")
|
||||||
|
event = SimpleNamespace(sender="@alice:example.org", body="hello")
|
||||||
|
|
||||||
|
await bot.on_room_message(room, event)
|
||||||
|
|
||||||
|
client.room_create.assert_awaited_once_with(
|
||||||
|
name="Чат 4",
|
||||||
|
visibility=RoomVisibility.private,
|
||||||
|
is_direct=False,
|
||||||
|
invite=["@alice:example.org"],
|
||||||
|
)
|
||||||
|
client.room_put_state.assert_awaited_once()
|
||||||
|
room_meta = await get_room_meta(runtime.store, "!chat4:example.org")
|
||||||
|
assert room_meta is not None
|
||||||
|
assert room_meta["chat_id"] == "C4"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_mat11_settings_returns_mvp_unavailable_message():
|
||||||
runtime = build_runtime(platform=MockPlatformClient())
|
runtime = build_runtime(platform=MockPlatformClient())
|
||||||
current_chat_id = "C9"
|
current_chat_id = "C9"
|
||||||
|
|
||||||
|
|
@ -238,15 +306,10 @@ async def test_mat11_settings_returns_dashboard():
|
||||||
)
|
)
|
||||||
result = await runtime.dispatcher.dispatch(settings_cmd)
|
result = await runtime.dispatcher.dispatch(settings_cmd)
|
||||||
|
|
||||||
assert len(result) >= 1
|
assert len(result) == 1
|
||||||
text = result[0].text
|
text = result[0].text
|
||||||
assert "Скиллы" in text or "скиллы" in text.lower()
|
assert "недоступна" in text.lower()
|
||||||
assert "Личность" in text
|
assert "mvp" in text.lower()
|
||||||
assert "Безопасность" in text
|
|
||||||
assert "Активные чаты" in text
|
|
||||||
assert "Изменить" not in text
|
|
||||||
assert "!connectors" not in text
|
|
||||||
assert "!whoami" not in text
|
|
||||||
|
|
||||||
|
|
||||||
async def test_mat12_help_returns_command_reference():
|
async def test_mat12_help_returns_command_reference():
|
||||||
|
|
@ -259,10 +322,26 @@ async def test_mat12_help_returns_command_reference():
|
||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
text = result[0].text
|
text = result[0].text
|
||||||
assert "!new" in text
|
assert "!new" in text
|
||||||
|
assert "!chats" in text
|
||||||
assert "!rename" in text
|
assert "!rename" in text
|
||||||
assert "!archive" in text
|
assert "!archive" in text
|
||||||
assert "!settings" in text
|
assert "!context" in text
|
||||||
assert "!yes" in text
|
assert "!save" in text
|
||||||
|
assert "!load" in text
|
||||||
|
assert "!reset" not in text
|
||||||
|
assert "!settings" not in text
|
||||||
|
assert "!skills" not in text
|
||||||
|
|
||||||
|
|
||||||
|
async def test_unknown_command_returns_helpful_message():
|
||||||
|
runtime = build_runtime(platform=MockPlatformClient())
|
||||||
|
|
||||||
|
result = await runtime.dispatcher.dispatch(
|
||||||
|
IncomingCommand(user_id="u1", platform="matrix", chat_id="C1", command="clear")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "неизвестная команда" in result[0].text.lower()
|
||||||
|
|
||||||
|
|
||||||
async def test_prepare_live_sync_returns_next_batch_from_bootstrap_sync():
|
async def test_prepare_live_sync_returns_next_batch_from_bootstrap_sync():
|
||||||
|
|
@ -302,3 +381,41 @@ async def test_build_runtime_uses_real_platform_when_matrix_backend_is_real(monk
|
||||||
|
|
||||||
assert isinstance(runtime.platform, RealPlatformClient)
|
assert isinstance(runtime.platform, RealPlatformClient)
|
||||||
assert runtime.platform.agent_api.url == "ws://agent.example/agent_ws/"
|
assert runtime.platform.agent_api.url == "ws://agent.example/agent_ws/"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_matrix_main_closes_platform_without_connecting_root_agent(monkeypatch):
|
||||||
|
bot_module = importlib.import_module("adapter.matrix.bot")
|
||||||
|
|
||||||
|
platform_close = AsyncMock()
|
||||||
|
agent_connect = AsyncMock()
|
||||||
|
runtime = SimpleNamespace(
|
||||||
|
platform=SimpleNamespace(
|
||||||
|
close=platform_close,
|
||||||
|
agent_api=SimpleNamespace(connect=agent_connect),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakeAsyncClient:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.access_token = None
|
||||||
|
self.callbacks = []
|
||||||
|
self.sync_forever = AsyncMock()
|
||||||
|
self.close = AsyncMock()
|
||||||
|
|
||||||
|
async def login(self, *args, **kwargs):
|
||||||
|
raise AssertionError("login should not be called when access token is provided")
|
||||||
|
|
||||||
|
def add_event_callback(self, callback, event_type):
|
||||||
|
self.callbacks.append((callback, event_type))
|
||||||
|
|
||||||
|
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||||
|
monkeypatch.setenv("MATRIX_USER_ID", "@bot:example.org")
|
||||||
|
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "token")
|
||||||
|
monkeypatch.setattr(bot_module, "AsyncClient", FakeAsyncClient)
|
||||||
|
monkeypatch.setattr(bot_module, "build_runtime", lambda **kwargs: runtime)
|
||||||
|
monkeypatch.setattr(bot_module, "prepare_live_sync", AsyncMock(return_value="s123"))
|
||||||
|
|
||||||
|
await bot_module.main()
|
||||||
|
|
||||||
|
agent_connect.assert_not_awaited()
|
||||||
|
platform_close.assert_awaited_once()
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,35 @@ class LegacyAgentApi:
|
||||||
self.last_tokens_used = 7
|
self.last_tokens_used = 7
|
||||||
|
|
||||||
|
|
||||||
|
class BlockingChatAgentApi:
|
||||||
|
def __init__(self, chat_id: str) -> None:
|
||||||
|
self.chat_id = chat_id
|
||||||
|
self.calls: list[str] = []
|
||||||
|
self.connect_calls = 0
|
||||||
|
self.close_calls = 0
|
||||||
|
self.last_tokens_used = 0
|
||||||
|
self.active_calls = 0
|
||||||
|
self.max_active_calls = 0
|
||||||
|
self.started = asyncio.Event()
|
||||||
|
self.release = asyncio.Event()
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
self.connect_calls += 1
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
self.close_calls += 1
|
||||||
|
|
||||||
|
async def send_message(self, text: str):
|
||||||
|
self.calls.append(text)
|
||||||
|
self.active_calls += 1
|
||||||
|
self.max_active_calls = max(self.max_active_calls, self.active_calls)
|
||||||
|
self.started.set()
|
||||||
|
await self.release.wait()
|
||||||
|
self.active_calls -= 1
|
||||||
|
yield FakeChunk(text)
|
||||||
|
self.last_tokens_used = len(text)
|
||||||
|
|
||||||
|
|
||||||
def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch):
|
def test_agent_api_wrapper_uses_modern_constructor_when_available(monkeypatch):
|
||||||
calls: list[dict[str, object]] = []
|
calls: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
|
@ -263,6 +292,42 @@ async def test_real_platform_client_creates_distinct_clients_per_chat():
|
||||||
assert agent_api.instances["chat-2"].calls == ["world"]
|
assert agent_api.instances["chat-2"].calls == ["world"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_platform_client_serializes_same_chat_streams_across_send_paths():
|
||||||
|
agent_api = FakeAgentApiFactory()
|
||||||
|
agent_api.instances["chat-1"] = BlockingChatAgentApi("chat-1")
|
||||||
|
agent_api.for_chat = lambda chat_id: agent_api.instances.setdefault(chat_id, BlockingChatAgentApi(chat_id))
|
||||||
|
client = RealPlatformClient(
|
||||||
|
agent_api=agent_api,
|
||||||
|
prototype_state=PrototypeStateStore(),
|
||||||
|
platform="matrix",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def consume_stream():
|
||||||
|
chunks = []
|
||||||
|
async for chunk in client.stream_message("@alice:example.org", "chat-1", "hello"):
|
||||||
|
chunks.append(chunk)
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
stream_task = asyncio.create_task(consume_stream())
|
||||||
|
await asyncio.wait_for(agent_api.instances["chat-1"].started.wait(), timeout=1)
|
||||||
|
|
||||||
|
send_task = asyncio.create_task(client.send_message("@alice:example.org", "chat-1", "again"))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert agent_api.instances["chat-1"].calls == ["hello"]
|
||||||
|
assert agent_api.instances["chat-1"].max_active_calls == 1
|
||||||
|
|
||||||
|
agent_api.instances["chat-1"].release.set()
|
||||||
|
stream_chunks = await stream_task
|
||||||
|
send_result = await send_task
|
||||||
|
|
||||||
|
assert [chunk.delta for chunk in stream_chunks] == ["hello", ""]
|
||||||
|
assert send_result.response == "again"
|
||||||
|
assert agent_api.instances["chat-1"].calls == ["hello", "again"]
|
||||||
|
assert agent_api.instances["chat-1"].max_active_calls == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
|
async def test_real_platform_client_stream_message_emits_final_tokens_chunk():
|
||||||
agent_api = FakeAgentApiFactory()
|
agent_api = FakeAgentApiFactory()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue