fix max-bot, add tests
This commit is contained in:
parent
7abbaf7e7a
commit
2ad1438e1c
17 changed files with 1621 additions and 494 deletions
|
|
@ -1,50 +1,141 @@
|
|||
"""Agent registry for MAX surface."""
|
||||
import os
|
||||
import yaml
|
||||
from typing import List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
id: str
|
||||
class AgentRegistryError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentDefinition:
|
||||
agent_id: str
|
||||
label: str
|
||||
base_url: str
|
||||
workspace_path: str
|
||||
base_url: str = field(default="")
|
||||
workspace_path: str = field(default="")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentAssignment:
|
||||
agent_id: str | None
|
||||
source: Literal["configured", "default", "none"]
|
||||
|
||||
@property
|
||||
def is_default(self) -> bool:
|
||||
return self.source == "default"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRegistry:
|
||||
agents: List[AgentConfig] = field(default_factory=list)
|
||||
"""Same contract as Matrix agent registry: user_agents maps MAX user_id string -> agent_id."""
|
||||
|
||||
def get_agent_for_user(self, user_id: str) -> AgentConfig:
|
||||
return self.agents[0]
|
||||
def __init__(
|
||||
self,
|
||||
agents: list[AgentDefinition],
|
||||
user_agents: Mapping[str, str] | None = None,
|
||||
) -> None:
|
||||
self.agents = tuple(agents)
|
||||
self._by_id = {agent.agent_id: agent for agent in self.agents}
|
||||
self._user_agents: dict[str, str] = dict(user_agents or {})
|
||||
|
||||
def get_agent_by_id(self, agent_id: str) -> Optional[AgentConfig]:
|
||||
for agent in self.agents:
|
||||
if agent.id == agent_id:
|
||||
return agent
|
||||
return None
|
||||
def get(self, agent_id: str) -> AgentDefinition:
|
||||
try:
|
||||
return self._by_id[agent_id]
|
||||
except KeyError as exc:
|
||||
raise AgentRegistryError(f"unknown agent id: {agent_id}") from exc
|
||||
|
||||
def get_agent_id_for_user(self, max_user_id: str) -> str | None:
|
||||
return self._user_agents.get(max_user_id)
|
||||
|
||||
def resolve_agent_for_user(self, max_user_id: str) -> AgentAssignment:
|
||||
agent_id = self.get_agent_id_for_user(max_user_id)
|
||||
if agent_id is not None:
|
||||
return AgentAssignment(agent_id=agent_id, source="configured")
|
||||
if self.agents:
|
||||
return AgentAssignment(agent_id=self.agents[0].agent_id, source="default")
|
||||
return AgentAssignment(agent_id=None, source="none")
|
||||
|
||||
|
||||
def load_agent_registry(path: str) -> AgentRegistry:
|
||||
with open(path, "r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
def _required_text(entry: Mapping[str, object], key: str) -> str:
|
||||
value = entry.get(key)
|
||||
if not isinstance(value, str):
|
||||
raise AgentRegistryError("each agent entry requires id and label")
|
||||
text = value.strip()
|
||||
if not text:
|
||||
raise AgentRegistryError("each agent entry requires id and label")
|
||||
return text
|
||||
|
||||
registry = AgentRegistry()
|
||||
for a in data.get("agents", []):
|
||||
registry.agents.append(AgentConfig(
|
||||
id=a["id"],
|
||||
label=a.get("label", ""),
|
||||
base_url=a["base_url"],
|
||||
workspace_path=a["workspace_path"],
|
||||
))
|
||||
return registry
|
||||
|
||||
def _optional_text(entry: Mapping[str, object], key: str) -> str:
|
||||
value = entry.get(key)
|
||||
if value is None:
|
||||
return ""
|
||||
if not isinstance(value, str):
|
||||
raise AgentRegistryError(f"agent entry field '{key}' must be a string")
|
||||
return value.strip()
|
||||
|
||||
|
||||
def _load_registry_data(path: str | Path) -> dict[str, object]:
|
||||
try:
|
||||
raw = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
|
||||
except yaml.YAMLError as exc:
|
||||
raise AgentRegistryError("invalid agent registry YAML") from exc
|
||||
if raw is None:
|
||||
return {}
|
||||
if not isinstance(raw, Mapping):
|
||||
raise AgentRegistryError("agent registry must be a mapping with an agents list")
|
||||
return dict(raw)
|
||||
|
||||
|
||||
def load_agent_registry(path: str | Path) -> AgentRegistry:
|
||||
raw = _load_registry_data(path)
|
||||
entries = raw.get("agents")
|
||||
if not isinstance(entries, list) or not entries:
|
||||
raise AgentRegistryError("agents registry must contain a non-empty agents list")
|
||||
|
||||
agents: list[AgentDefinition] = []
|
||||
seen: set[str] = set()
|
||||
for entry in entries:
|
||||
if not isinstance(entry, Mapping):
|
||||
raise AgentRegistryError("each agent entry requires id and label")
|
||||
agent_id = _required_text(entry, "id")
|
||||
label = _required_text(entry, "label")
|
||||
base_url = _optional_text(entry, "base_url")
|
||||
workspace_path = _optional_text(entry, "workspace_path")
|
||||
if agent_id in seen:
|
||||
raise AgentRegistryError(f"duplicate agent id: {agent_id}")
|
||||
seen.add(agent_id)
|
||||
agents.append(
|
||||
AgentDefinition(
|
||||
agent_id=agent_id,
|
||||
label=label,
|
||||
base_url=base_url,
|
||||
workspace_path=workspace_path,
|
||||
)
|
||||
)
|
||||
|
||||
user_agents = raw.get("user_agents")
|
||||
if user_agents is not None:
|
||||
if not isinstance(user_agents, Mapping):
|
||||
raise AgentRegistryError("user_agents must be a mapping of user id strings to agent ids")
|
||||
normalized: dict[str, str] = {}
|
||||
for uid, aid in user_agents.items():
|
||||
if not isinstance(uid, str) or not isinstance(aid, str):
|
||||
raise AgentRegistryError("user_agents keys and values must be strings")
|
||||
normalized[uid.strip()] = aid.strip()
|
||||
user_agents_map: Mapping[str, str] = normalized
|
||||
else:
|
||||
user_agents_map = {}
|
||||
|
||||
return AgentRegistry(agents=agents, user_agents=user_agents_map)
|
||||
|
||||
|
||||
def load_from_env() -> AgentRegistry:
|
||||
path = os.environ.get(
|
||||
"MAX_AGENT_REGISTRY_PATH",
|
||||
"/app/config/max-agents.yaml",
|
||||
)
|
||||
return load_agent_registry(path)
|
||||
import os
|
||||
|
||||
path = os.environ.get("MAX_AGENT_REGISTRY_PATH", "/app/config/max-agents.yaml")
|
||||
return load_agent_registry(path)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue