Merge pull request #2361 from NousResearch/hermes/hermes-5d6932ba

feat(gateway): cache AIAgent per session for prompt caching
This commit is contained in:
Teknium 2026-03-21 16:53:21 -07:00 committed by GitHub
commit 52dd479214
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 348 additions and 28 deletions

View file

@ -344,6 +344,15 @@ class GatewayRunner:
self._running_agents: Dict[str, Any] = {}
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
# Cache AIAgent instances per session to preserve prompt caching.
# Without this, a new AIAgent is created per message, rebuilding the
# system prompt (including memory) every turn — breaking prefix cache
# and costing ~10x more on providers with prompt caching (Anthropic).
# Key: session_key, Value: (AIAgent, config_signature_str)
import threading as _threading
self._agent_cache: Dict[str, tuple] = {}
self._agent_cache_lock = _threading.Lock()
# Track active fallback model/provider when primary is rate-limited.
# Set after an agent run where fallback was activated; cleared when
# the primary model succeeds again or the user switches via /model.
@ -2386,6 +2395,7 @@ class GatewayRunner:
logger.debug("Gateway memory flush on reset failed: %s", e)
self._shutdown_gateway_honcho(session_key)
self._evict_cached_agent(session_key)
# Reset the session
new_entry = self.session_store.reset_session(session_key)
@ -4501,6 +4511,45 @@ class GatewayRunner:
_MAX_INTERRUPT_DEPTH = 3 # Cap recursive interrupt handling (#816)
@staticmethod
def _agent_config_signature(
model: str,
runtime: dict,
enabled_toolsets: list,
ephemeral_prompt: str,
) -> str:
"""Compute a stable string key from agent config values.
When this signature changes between messages, the cached AIAgent is
discarded and rebuilt. When it stays the same, the cached agent is
reused preserving the frozen system prompt and tool schemas for
prompt cache hits.
"""
import hashlib, json as _j
blob = _j.dumps(
[
model,
runtime.get("api_key", "")[:8], # first 8 chars only
runtime.get("base_url", ""),
runtime.get("provider", ""),
runtime.get("api_mode", ""),
sorted(enabled_toolsets) if enabled_toolsets else [],
# reasoning_config excluded — it's set per-message on the
# cached agent and doesn't affect system prompt or tools.
ephemeral_prompt or "",
],
sort_keys=True,
default=str,
)
return hashlib.sha256(blob.encode()).hexdigest()[:16]
def _evict_cached_agent(self, session_key: str) -> None:
"""Remove a cached agent for a session (called on /new, /model, etc)."""
_lock = getattr(self, "_agent_cache_lock", None)
if _lock:
with _lock:
self._agent_cache.pop(session_key, None)
async def _run_agent(
self,
message: str,
@ -4850,34 +4899,64 @@ class GatewayRunner:
logger.debug("Could not set up stream consumer: %s", _sc_err)
turn_route = self._resolve_turn_agent_config(message, model, runtime_kwargs)
agent = AIAgent(
model=turn_route["model"],
**turn_route["runtime"],
max_iterations=max_iterations,
quiet_mode=True,
verbose_logging=False,
enabled_toolsets=enabled_toolsets,
ephemeral_system_prompt=combined_ephemeral or None,
prefill_messages=self._prefill_messages or None,
reasoning_config=reasoning_config,
providers_allowed=pr.get("only"),
providers_ignored=pr.get("ignore"),
providers_order=pr.get("order"),
provider_sort=pr.get("sort"),
provider_require_parameters=pr.get("require_parameters", False),
provider_data_collection=pr.get("data_collection"),
session_id=session_id,
tool_progress_callback=progress_callback if tool_progress_enabled else None,
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
stream_delta_callback=_stream_delta_cb,
status_callback=_status_callback_sync,
platform=platform_key,
honcho_session_key=session_key,
honcho_manager=honcho_manager,
honcho_config=honcho_config,
session_db=self._session_db,
fallback_model=self._fallback_model,
# Check agent cache — reuse the AIAgent from the previous message
# in this session to preserve the frozen system prompt and tool
# schemas for prompt cache hits.
_sig = self._agent_config_signature(
turn_route["model"],
turn_route["runtime"],
enabled_toolsets,
combined_ephemeral,
)
agent = None
_cache_lock = getattr(self, "_agent_cache_lock", None)
_cache = getattr(self, "_agent_cache", None)
if _cache_lock and _cache is not None:
with _cache_lock:
cached = _cache.get(session_key)
if cached and cached[1] == _sig:
agent = cached[0]
logger.debug("Reusing cached agent for session %s", session_key)
if agent is None:
# Config changed or first message — create fresh agent
agent = AIAgent(
model=turn_route["model"],
**turn_route["runtime"],
max_iterations=max_iterations,
quiet_mode=True,
verbose_logging=False,
enabled_toolsets=enabled_toolsets,
ephemeral_system_prompt=combined_ephemeral or None,
prefill_messages=self._prefill_messages or None,
reasoning_config=reasoning_config,
providers_allowed=pr.get("only"),
providers_ignored=pr.get("ignore"),
providers_order=pr.get("order"),
provider_sort=pr.get("sort"),
provider_require_parameters=pr.get("require_parameters", False),
provider_data_collection=pr.get("data_collection"),
session_id=session_id,
platform=platform_key,
honcho_session_key=session_key,
honcho_manager=honcho_manager,
honcho_config=honcho_config,
session_db=self._session_db,
fallback_model=self._fallback_model,
)
if _cache_lock and _cache is not None:
with _cache_lock:
_cache[session_key] = (agent, _sig)
logger.debug("Created new agent for session %s (sig=%s)", session_key, _sig)
# Per-message state — callbacks and reasoning config change every
# turn and must not be baked into the cached agent constructor.
agent.tool_progress_callback = progress_callback if tool_progress_enabled else None
agent.step_callback = _step_callback_sync if _hooks_ref.loaded_hooks else None
agent.stream_delta_callback = _stream_delta_cb
agent.status_callback = _status_callback_sync
agent.reasoning_config = reasoning_config
# Store agent reference for interrupt support
agent_holder[0] = agent
@ -5122,6 +5201,9 @@ class GatewayRunner:
if _agent.model != _cfg_model:
self._effective_model = _agent.model
self._effective_provider = getattr(_agent, 'provider', None)
# Fallback activated — evict cached agent so the next
# message starts fresh and retries the primary model.
self._evict_cached_agent(session_key)
else:
# Primary model worked — clear any stale fallback state
self._effective_model = None