Merge pull request #1327 from NousResearch/hermes/hermes-048e6599
Merging the non-redundant fixes salvaged from #993 onto current main, plus adjacent trajectory compressor hardening found during review.
This commit is contained in:
commit
6d8286f396
6 changed files with 84 additions and 17 deletions
|
|
@@ -39,7 +39,9 @@ def resize_tool_pool(max_workers: int):
|
||||||
Safe to call before any tasks are submitted.
|
Safe to call before any tasks are submitted.
|
||||||
"""
|
"""
|
||||||
global _tool_executor
|
global _tool_executor
|
||||||
|
old_executor = _tool_executor
|
||||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||||
|
old_executor.shutdown(wait=False)
|
||||||
logger.info("Tool thread pool resized to %d workers", max_workers)
|
logger.info("Tool thread pool resized to %d workers", max_workers)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
||||||
|
|
@@ -161,7 +161,7 @@ class DeliveryRouter:
|
||||||
|
|
||||||
# Always include local if configured
|
# Always include local if configured
|
||||||
if self.config.always_log_local:
|
if self.config.always_log_local:
|
||||||
local_key = (Platform.LOCAL, None)
|
local_key = (Platform.LOCAL, None, None)
|
||||||
if local_key not in seen_platforms:
|
if local_key not in seen_platforms:
|
||||||
targets.append(DeliveryTarget(platform=Platform.LOCAL))
|
targets.append(DeliveryTarget(platform=Platform.LOCAL))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -1,7 +1,7 @@
|
||||||
"""Tests for the delivery routing module."""
|
"""Tests for the delivery routing module."""
|
||||||
|
|
||||||
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
|
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
|
||||||
from gateway.delivery import DeliveryTarget, parse_deliver_spec
|
from gateway.delivery import DeliveryRouter, DeliveryTarget, parse_deliver_spec
|
||||||
from gateway.session import SessionSource
|
from gateway.session import SessionSource
|
||||||
|
|
||||||
|
|
||||||
|
|
@@ -85,3 +85,12 @@ class TestTargetToStringRoundtrip:
|
||||||
reparsed = DeliveryTarget.parse(s)
|
reparsed = DeliveryTarget.parse(s)
|
||||||
assert reparsed.platform == Platform.TELEGRAM
|
assert reparsed.platform == Platform.TELEGRAM
|
||||||
assert reparsed.chat_id == "999"
|
assert reparsed.chat_id == "999"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeliveryRouter:
|
||||||
|
def test_resolve_targets_does_not_duplicate_local_when_explicit(self):
|
||||||
|
router = DeliveryRouter(GatewayConfig(always_log_local=True))
|
||||||
|
|
||||||
|
targets = router.resolve_targets(["local"])
|
||||||
|
|
||||||
|
assert [target.platform for target in targets] == [Platform.LOCAL]
|
||||||
|
|
|
||||||
|
|
@@ -484,3 +484,22 @@ class TestResizeToolPool:
|
||||||
"""resize_tool_pool should not raise."""
|
"""resize_tool_pool should not raise."""
|
||||||
resize_tool_pool(16) # Small pool for testing
|
resize_tool_pool(16) # Small pool for testing
|
||||||
resize_tool_pool(128) # Restore default
|
resize_tool_pool(128) # Restore default
|
||||||
|
|
||||||
|
def test_resize_shuts_down_previous_executor(self, monkeypatch):
|
||||||
|
"""Replacing the global tool executor should shut down the old pool."""
|
||||||
|
import environments.agent_loop as agent_loop_module
|
||||||
|
|
||||||
|
old_executor = MagicMock()
|
||||||
|
new_executor = MagicMock()
|
||||||
|
|
||||||
|
monkeypatch.setattr(agent_loop_module, "_tool_executor", old_executor)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
agent_loop_module.concurrent.futures,
|
||||||
|
"ThreadPoolExecutor",
|
||||||
|
MagicMock(return_value=new_executor),
|
||||||
|
)
|
||||||
|
|
||||||
|
resize_tool_pool(16)
|
||||||
|
|
||||||
|
old_executor.shutdown.assert_called_once_with(wait=False)
|
||||||
|
assert agent_loop_module._tool_executor is new_executor
|
||||||
|
|
|
||||||
|
|
@@ -1,7 +1,10 @@
|
||||||
"""Tests for trajectory_compressor.py — config, metrics, and compression logic."""
|
"""Tests for trajectory_compressor.py — config, metrics, and compression logic."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import patch, MagicMock
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from trajectory_compressor import (
|
from trajectory_compressor import (
|
||||||
CompressionConfig,
|
CompressionConfig,
|
||||||
|
|
@@ -384,3 +387,32 @@ class TestTokenCounting:
|
||||||
tc.tokenizer.encode = MagicMock(side_effect=Exception("fail"))
|
tc.tokenizer.encode = MagicMock(side_effect=Exception("fail"))
|
||||||
# Should fallback to len(text) // 4
|
# Should fallback to len(text) // 4
|
||||||
assert tc.count_tokens("12345678") == 2
|
assert tc.count_tokens("12345678") == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateSummary:
|
||||||
|
def test_generate_summary_handles_none_content(self):
|
||||||
|
tc = _make_compressor()
|
||||||
|
tc.client = MagicMock()
|
||||||
|
tc.client.chat.completions.create.return_value = SimpleNamespace(
|
||||||
|
choices=[SimpleNamespace(message=SimpleNamespace(content=None))]
|
||||||
|
)
|
||||||
|
metrics = TrajectoryMetrics()
|
||||||
|
|
||||||
|
summary = tc._generate_summary("Turn content", metrics)
|
||||||
|
|
||||||
|
assert summary == "[CONTEXT SUMMARY]:"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_summary_async_handles_none_content(self):
|
||||||
|
tc = _make_compressor()
|
||||||
|
tc.async_client = MagicMock()
|
||||||
|
tc.async_client.chat.completions.create = AsyncMock(
|
||||||
|
return_value=SimpleNamespace(
|
||||||
|
choices=[SimpleNamespace(message=SimpleNamespace(content=None))]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metrics = TrajectoryMetrics()
|
||||||
|
|
||||||
|
summary = await tc._generate_summary_async("Turn content", metrics)
|
||||||
|
|
||||||
|
assert summary == "[CONTEXT SUMMARY]:"
|
||||||
|
|
|
||||||
|
|
@@ -496,6 +496,21 @@ class TrajectoryCompressor:
|
||||||
|
|
||||||
return "\n\n".join(parts)
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_summary_content(content: Any) -> str:
|
||||||
|
"""Normalize summary-model output to a safe string."""
|
||||||
|
if not isinstance(content, str):
|
||||||
|
content = str(content) if content else ""
|
||||||
|
return content.strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_summary_prefix(summary: str) -> str:
|
||||||
|
"""Normalize summary text to include the expected prefix exactly once."""
|
||||||
|
text = (summary or "").strip()
|
||||||
|
if text.startswith("[CONTEXT SUMMARY]:"):
|
||||||
|
return text
|
||||||
|
return "[CONTEXT SUMMARY]:" if not text else f"[CONTEXT SUMMARY]: {text}"
|
||||||
|
|
||||||
def _generate_summary(self, content: str, metrics: TrajectoryMetrics) -> str:
|
def _generate_summary(self, content: str, metrics: TrajectoryMetrics) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a summary of the compressed turns using OpenRouter.
|
Generate a summary of the compressed turns using OpenRouter.
|
||||||
|
|
@@ -545,13 +560,8 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||||
max_tokens=self.config.summary_target_tokens * 2,
|
max_tokens=self.config.summary_target_tokens * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
summary = response.choices[0].message.content.strip()
|
summary = self._coerce_summary_content(response.choices[0].message.content)
|
||||||
|
return self._ensure_summary_prefix(summary)
|
||||||
# Ensure it starts with the prefix
|
|
||||||
if not summary.startswith("[CONTEXT SUMMARY]:"):
|
|
||||||
summary = "[CONTEXT SUMMARY]: " + summary
|
|
||||||
|
|
||||||
return summary
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
metrics.summarization_errors += 1
|
metrics.summarization_errors += 1
|
||||||
|
|
@@ -612,13 +622,8 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||||
max_tokens=self.config.summary_target_tokens * 2,
|
max_tokens=self.config.summary_target_tokens * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
summary = response.choices[0].message.content.strip()
|
summary = self._coerce_summary_content(response.choices[0].message.content)
|
||||||
|
return self._ensure_summary_prefix(summary)
|
||||||
# Ensure it starts with the prefix
|
|
||||||
if not summary.startswith("[CONTEXT SUMMARY]:"):
|
|
||||||
summary = "[CONTEXT SUMMARY]: " + summary
|
|
||||||
|
|
||||||
return summary
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
metrics.summarization_errors += 1
|
metrics.summarization_errors += 1
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue