feat: support shared-workspace file flow for matrix

This commit is contained in:
Mikhail Putilovskij 2026-04-21 00:26:21 +03:00
parent 323a6d3144
commit 6422c7db58
18 changed files with 871 additions and 80 deletions

View file

@ -86,6 +86,55 @@ class AgentApiWrapper(AgentApi):
**self._init_kwargs,
)
@staticmethod
def _event_kind(event: object) -> str:
raw_kind = getattr(event, "type", None)
if hasattr(raw_kind, "value"):
raw_kind = raw_kind.value
if raw_kind is None:
raw_kind = event.__class__.__name__
kind = str(raw_kind).replace("-", "_")
if "_" in kind:
return kind.upper()
normalized = []
for index, char in enumerate(kind):
if index and char.isupper() and not kind[index - 1].isupper():
normalized.append("_")
normalized.append(char)
return "".join(normalized).upper()
@classmethod
def _is_kind(cls, event: object, *needles: str) -> bool:
kind = cls._event_kind(event)
return any(needle in kind for needle in needles)
@classmethod
def _is_text_event(cls, event: object) -> bool:
return hasattr(event, "text") or cls._is_kind(event, "TEXT_CHUNK")
@classmethod
def _is_end_event(cls, event: object) -> bool:
kind = cls._event_kind(event)
return kind == "END" or kind.endswith("_END")
@classmethod
def _is_send_file_event(cls, event: object) -> bool:
return "SEND_FILE" in cls._event_kind(event)
async def _publish_event(self, event: object, *, queue_event: object | None = None) -> None:
if self.callback:
self.callback(event)
if self._current_queue:
await self._current_queue.put(queue_event if queue_event is not None else event)
async def _publish_error(self, event: object) -> None:
if self.callback:
self.callback(event)
if self._current_queue and hasattr(event, "code") and hasattr(event, "details"):
await self._current_queue.put(AgentException(getattr(event, "code"), getattr(event, "details")))
async def _listen(self):
try:
async for msg in self._ws:
@ -93,7 +142,7 @@ class AgentApiWrapper(AgentApi):
try:
outgoing_msg = ServerMessage.validate_json(msg.data)
if isinstance(outgoing_msg, MsgEventTextChunk):
if self._is_text_event(outgoing_msg):
if self._current_queue:
await self._current_queue.put(outgoing_msg)
elif self.callback:
@ -101,29 +150,22 @@ class AgentApiWrapper(AgentApi):
else:
logger.warning("[%s] AgentEvent without active request", self.id)
elif isinstance(outgoing_msg, MsgEventEnd):
elif self._is_end_event(outgoing_msg):
self.last_tokens_used = outgoing_msg.tokens_used
if self._current_queue:
await self._current_queue.put(outgoing_msg)
await self._publish_event(outgoing_msg)
elif isinstance(outgoing_msg, MsgError):
if self.callback:
self.callback(outgoing_msg)
elif self._is_kind(outgoing_msg, "ERROR"):
error = AgentException(outgoing_msg.code, outgoing_msg.details)
logger.error("[%s] Agent error: %s", self.id, error)
if self._current_queue:
await self._current_queue.put(error)
await self._publish_error(outgoing_msg)
elif isinstance(outgoing_msg, MsgGracefulDisconnect):
if self.callback:
self.callback(outgoing_msg)
elif self._is_kind(outgoing_msg, "GRACEFUL_DISCONNECT"):
await self._publish_event(outgoing_msg)
logger.info("[%s] Gracefully disconnecting", self.id)
break
else:
logger.warning("[%s] Unknown message type: %s", self.id, outgoing_msg.type)
if self.callback:
self.callback(outgoing_msg)
await self._publish_event(outgoing_msg)
except Exception as exc:
logger.error("[%s] Failed to deserialize message: %s", self.id, exc)