Merge pull request #2091 from dusterbloom/fix/lmstudio-context-length-detection
feat: query local servers for actual context window size
This commit is contained in:
commit
3a9a1bbb84
3 changed files with 742 additions and 9 deletions
|
|
@ -146,6 +146,9 @@ _MAX_COMPLETION_KEYS = (
|
||||||
"max_tokens",
|
"max_tokens",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Local server hostnames / address patterns
|
||||||
|
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
|
||||||
|
|
||||||
|
|
||||||
def _normalize_base_url(base_url: str) -> str:
|
def _normalize_base_url(base_url: str) -> str:
|
||||||
return (base_url or "").strip().rstrip("/")
|
return (base_url or "").strip().rstrip("/")
|
||||||
|
|
@ -178,6 +181,99 @@ def _is_known_provider_base_url(base_url: str) -> bool:
|
||||||
return any(known_host in host for known_host in known_hosts)
|
return any(known_host in host for known_host in known_hosts)
|
||||||
|
|
||||||
|
|
||||||
|
def is_local_endpoint(base_url: str) -> bool:
    """Return True if *base_url* points at a local or private-network host.

    A host counts as local when it is listed in ``_LOCAL_HOSTS`` or is an IP
    literal that is private, loopback or link-local (this covers the 172.x
    addresses used by WSL).

    Args:
        base_url: Endpoint URL or bare ``host[:port]``. A missing scheme is
            treated as ``http://``.

    Returns:
        True for local/private hosts; False for everything else, including
        empty or unparseable input.
    """
    import ipaddress

    normalized = _normalize_base_url(base_url)
    if not normalized:
        return False

    url = normalized if "://" in normalized else f"http://{normalized}"
    try:
        host = urlparse(url).hostname or ""
    except ValueError:
        # urlparse / .hostname reject malformed input such as bad IPv6 brackets.
        return False

    if host in _LOCAL_HOSTS:
        return True

    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        pass
    else:
        return addr.is_private or addr.is_loopback or addr.is_link_local

    # Lenient fallback for dotted quads that ``ipaddress`` refuses (for example
    # octets written with leading zeros).  Every label must be a numeric octet:
    # without that check a public DNS name such as "10.0.example.com" or
    # "172.20.foo.bar" would be misclassified as a private address.
    labels = host.split(".")
    if len(labels) != 4:
        return False
    if not all(lbl.isascii() and lbl.isdigit() and len(lbl) <= 3 for lbl in labels):
        return False
    octets = [int(lbl) for lbl in labels]
    if max(octets) > 255:
        return False

    first, second = octets[0], octets[1]
    return (
        first == 10
        or (first == 172 and 16 <= second <= 31)
        or (first == 192 and second == 168)
    )
|
||||||
|
|
||||||
|
|
||||||
|
def detect_local_server_type(base_url: str) -> Optional[str]:
    """Identify the local inference server behind *base_url* by probing it.

    Probes run from most to least specific and the first match wins:

    1. ``GET /api/v1/models`` answers 200                       -> ``"lm-studio"``
    2. ``GET /api/tags`` answers 200 with a JSON object that
       contains ``"models"``                                    -> ``"ollama"``
    3. ``GET /props`` answers 200 and the body mentions
       ``default_generation_settings``                          -> ``"llamacpp"``
    4. ``GET /version`` answers 200 with a JSON object that
       contains ``"version"``                                   -> ``"vllm"``

    Args:
        base_url: Server URL, with or without a trailing ``/v1``.

    Returns:
        ``"ollama"``, ``"lm-studio"``, ``"vllm"``, ``"llamacpp"``, or ``None``
        when nothing matched, the server is unreachable, or *base_url* is
        empty.  Detection is best-effort and never raises.
    """
    import httpx

    server_url = _normalize_base_url(base_url)
    if server_url.endswith("/v1"):
        server_url = server_url[:-3]
    if not server_url:
        # Nothing to probe; skip four requests that are guaranteed to fail.
        return None

    def _json_has(key):
        """Build a check that passes when the body is a JSON object with *key*."""
        def _check(resp):
            payload = resp.json()
            # Require a dict: ``key in payload`` on a list or string would be
            # an element/substring test and could match by accident.
            return isinstance(payload, dict) and key in payload
        return _check

    def _probe(client, path, accept=None) -> bool:
        """Return True if GET *path* answers 200 and satisfies *accept*."""
        try:
            resp = client.get(f"{server_url}{path}")
            if resp.status_code != 200:
                return False
            return accept is None or bool(accept(resp))
        except Exception:
            # Best effort: a failing probe only means "not this server".
            return False

    probes = (
        # LM Studio's native API is the most specific signal, so it goes first.
        ("lm-studio", "/api/v1/models", None),
        # LM Studio answers unknown paths with 200 + {"error": ...}, so the
        # status code alone is not enough to recognise Ollama.
        ("ollama", "/api/tags", _json_has("models")),
        ("llamacpp", "/props",
         lambda resp: "default_generation_settings" in resp.text),
        ("vllm", "/version", _json_has("version")),
    )

    try:
        with httpx.Client(timeout=2.0) as client:
            for server_type, path, accept in probes:
                if _probe(client, path, accept):
                    return server_type
    except Exception:
        # Client construction/teardown failed; treat as "unknown server".
        pass

    return None
|
||||||
|
|
||||||
|
|
||||||
def _iter_nested_dicts(value: Any):
|
def _iter_nested_dicts(value: Any):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
yield value
|
yield value
|
||||||
|
|
@ -383,7 +479,7 @@ def _get_context_cache_path() -> Path:
|
||||||
|
|
||||||
|
|
||||||
def _load_context_cache() -> Dict[str, int]:
|
def _load_context_cache() -> Dict[str, int]:
|
||||||
"""Load the model+provider → context_length cache from disk."""
|
"""Load the model+provider -> context_length cache from disk."""
|
||||||
path = _get_context_cache_path()
|
path = _get_context_cache_path()
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
return {}
|
return {}
|
||||||
|
|
@ -412,7 +508,7 @@ def save_context_length(model: str, base_url: str, length: int) -> None:
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
|
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
|
||||||
logger.info("Cached context length %s → %s tokens", key, f"{length:,}")
|
logger.info("Cached context length %s -> %s tokens", key, f"{length:,}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Failed to save context length cache: %s", e)
|
logger.debug("Failed to save context length cache: %s", e)
|
||||||
|
|
||||||
|
|
@ -460,6 +556,116 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
|
||||||
|
"""Return True if *candidate_id* (from server) matches *lookup_model* (configured).
|
||||||
|
|
||||||
|
Supports two forms:
|
||||||
|
- Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1"
|
||||||
|
- Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches "nvidia-nemotron-super-49b-v1"
|
||||||
|
(the part after the last "/" equals lookup_model)
|
||||||
|
|
||||||
|
This covers LM Studio's native API which stores models as "publisher/slug"
|
||||||
|
while users typically configure only the slug after the "local:" prefix.
|
||||||
|
"""
|
||||||
|
if candidate_id == lookup_model:
|
||||||
|
return True
|
||||||
|
# Slug match: basename of candidate equals the lookup name
|
||||||
|
if "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Keys under which OpenAI-compatible servers report the context window,
# in order of preference (vLLM uses ``max_model_len``).
_OPENAI_COMPAT_CONTEXT_KEYS = ("max_model_len", "context_length", "max_tokens")


def _coerce_context_length(value: Any) -> Optional[int]:
    """Return *value* as a positive int, or None if it is not a usable token count."""
    # bool is a subclass of int; ``True`` must not be read as a 1-token window.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    try:
        length = int(value)
    except (ValueError, OverflowError):  # NaN / infinity
        return None
    return length if length > 0 else None


def _first_context_length(payload: Any, keys) -> Optional[int]:
    """Return the first usable context length stored under *keys* in *payload*."""
    if not isinstance(payload, dict):
        return None
    for key in keys:
        length = _coerce_context_length(payload.get(key))
        if length is not None:
            return length
    return None


def _ollama_context_length(client, server_url: str, model: str) -> Optional[int]:
    """Read *model*'s context length from Ollama's ``POST /api/show``."""
    resp = client.post(f"{server_url}/api/show", json={"name": model})
    if resp.status_code != 200:
        return None
    data = resp.json()
    if not isinstance(data, dict):
        return None

    model_info = data.get("model_info")
    if isinstance(model_info, dict):
        # Prefer "<arch>.context_length" exactly; a plain substring test would
        # also hit "<arch>.rope.scaling.original_context_length".
        exact = [
            v for k, v in model_info.items()
            if isinstance(k, str)
            and (k == "context_length" or k.endswith(".context_length"))
        ]
        loose = [
            v for k, v in model_info.items()
            if isinstance(k, str) and "context_length" in k
        ]
        for value in exact + loose:
            length = _coerce_context_length(value)
            if length is not None:
                return length

    # Fall back to a "num_ctx <n>" line in the parameters text.
    params = data.get("parameters")
    if isinstance(params, str):
        for line in params.split("\n"):
            if "num_ctx" not in line:
                continue
            parts = line.split()
            if len(parts) < 2:
                continue
            try:
                length = _coerce_context_length(int(parts[-1]))
            except ValueError:
                continue
            if length is not None:
                return length
    return None


def _lmstudio_context_length(client, server_url: str, model: str) -> Optional[int]:
    """Read *model*'s context length from LM Studio's native ``/api/v1/models``."""
    resp = client.get(f"{server_url}/api/v1/models")
    if resp.status_code != 200:
        return None
    data = resp.json()
    entries = data.get("models") if isinstance(data, dict) else None
    for entry in entries or ():
        if not isinstance(entry, dict):
            continue
        # LM Studio stores "publisher/slug"; users often configure only "slug".
        if not (
            _model_id_matches(entry.get("key") or "", model)
            or _model_id_matches(entry.get("id") or "", model)
        ):
            continue
        # A loaded instance carries the runtime setting, which is the real
        # limit and therefore beats the model's theoretical maximum.
        for inst in entry.get("loaded_instances") or ():
            cfg = inst.get("config") if isinstance(inst, dict) else None
            length = _first_context_length(cfg, ("context_length",))
            if length is not None:
                return length
        length = _first_context_length(
            entry, ("max_context_length", "context_length")
        )
        if length is not None:
            return length
    return None


def _openai_model_detail_context_length(
    client, server_url: str, model: str
) -> Optional[int]:
    """Read the context length from ``GET /v1/models/{model}``."""
    resp = client.get(f"{server_url}/v1/models/{model}")
    if resp.status_code != 200:
        return None
    return _first_context_length(resp.json(), _OPENAI_COMPAT_CONTEXT_KEYS)


def _openai_model_list_context_length(
    client, server_url: str, model: str
) -> Optional[int]:
    """Find *model* in ``GET /v1/models`` and read its context length."""
    resp = client.get(f"{server_url}/v1/models")
    if resp.status_code != 200:
        return None
    data = resp.json()
    entries = data.get("data") if isinstance(data, dict) else None
    for entry in entries or ():
        if not isinstance(entry, dict):
            continue
        if _model_id_matches(entry.get("id") or "", model):
            length = _first_context_length(entry, _OPENAI_COMPAT_CONTEXT_KEYS)
            if length is not None:
                return length
    return None


def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
    """Ask a local inference server for *model*'s context window size.

    Lookup stages, in order (the first positive value wins):

    1. Ollama ``POST /api/show`` (only when the server is detected as Ollama)
    2. LM Studio native ``GET /api/v1/models`` (only when detected as LM Studio)
    3. ``GET /v1/models/{model}``
    4. ``GET /v1/models`` list, matched via ``_model_id_matches``

    Args:
        model: Configured model name.  A ``provider:`` prefix such as
            ``local:`` is stripped before querying.
        base_url: Server URL, with or without a trailing ``/v1``.

    Returns:
        The context length in tokens, or ``None`` when the server is
        unreachable or does not report one.  Never raises for network or
        payload problems.
    """
    import httpx

    # Strip provider prefix (e.g. "local:model-name" -> "model-name").
    requested = model
    if ":" in model and not model.startswith("http"):
        model = model.split(":", 1)[1]

    # Strip the /v1 suffix to get the server root.
    server_url = _normalize_base_url(base_url)
    if server_url.endswith("/v1"):
        server_url = server_url[:-3]
    if not server_url:
        return None

    try:
        server_type = detect_local_server_type(base_url)
    except Exception:
        server_type = None

    stages = []
    if server_type == "ollama":
        stages.append((_ollama_context_length, model))
        if requested != model:
            # For Ollama the colon may be a "name:tag" separator rather than a
            # provider prefix, so retry with the name exactly as given.
            stages.append((_ollama_context_length, requested))
    if server_type == "lm-studio":
        stages.append((_lmstudio_context_length, model))
    stages.append((_openai_model_detail_context_length, model))
    stages.append((_openai_model_list_context_length, model))

    try:
        with httpx.Client(timeout=3.0) as client:
            for stage, name in stages:
                try:
                    length = stage(client, server_url, name)
                except httpx.TransportError as exc:
                    # Server unreachable or timed out: further requests would
                    # only add latency, so give up.
                    logger.debug("Local context query aborted: %s", exc)
                    return None
                except Exception as exc:
                    # A malformed payload from one endpoint must not prevent
                    # the remaining fallbacks from running.
                    logger.debug(
                        "Local context query stage %s failed: %s",
                        stage.__name__, exc,
                    )
                    continue
                if length is not None:
                    return length
    except Exception as exc:
        logger.debug("Local context query failed: %s", exc)

    return None
|
||||||
|
|
||||||
|
|
||||||
def get_model_context_length(
|
def get_model_context_length(
|
||||||
model: str,
|
model: str,
|
||||||
base_url: str = "",
|
base_url: str = "",
|
||||||
|
|
@ -472,14 +678,21 @@ def get_model_context_length(
|
||||||
0. Explicit config override (model.context_length in config.yaml)
|
0. Explicit config override (model.context_length in config.yaml)
|
||||||
1. Persistent cache (previously discovered via probing)
|
1. Persistent cache (previously discovered via probing)
|
||||||
2. Active endpoint metadata (/models for explicit custom endpoints)
|
2. Active endpoint metadata (/models for explicit custom endpoints)
|
||||||
3. OpenRouter API metadata
|
3. Local server query (for local endpoints when model not in /models list)
|
||||||
4. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
|
4. OpenRouter API metadata
|
||||||
5. First probe tier (2M) — will be narrowed on first context error
|
5. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
|
||||||
|
6. First probe tier (2M) — will be narrowed on first context error
|
||||||
"""
|
"""
|
||||||
# 0. Explicit config override — user knows best
|
# 0. Explicit config override — user knows best
|
||||||
if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
|
if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
|
||||||
return config_context_length
|
return config_context_length
|
||||||
|
|
||||||
|
# Normalise provider-prefixed model names (e.g. "local:model-name" →
|
||||||
|
# "model-name") so cache lookups and server queries use the bare ID that
|
||||||
|
# local servers actually know about.
|
||||||
|
if ":" in model and not model.startswith("http"):
|
||||||
|
model = model.split(":", 1)[1]
|
||||||
|
|
||||||
# 1. Check persistent cache (model+provider)
|
# 1. Check persistent cache (model+provider)
|
||||||
if base_url:
|
if base_url:
|
||||||
cached = get_cached_context_length(model, base_url)
|
cached = get_cached_context_length(model, base_url)
|
||||||
|
|
@ -507,6 +720,12 @@ def get_model_context_length(
|
||||||
if not _is_known_provider_base_url(base_url):
|
if not _is_known_provider_base_url(base_url):
|
||||||
# Explicit third-party endpoints should not borrow fuzzy global
|
# Explicit third-party endpoints should not borrow fuzzy global
|
||||||
# defaults from unrelated providers with similarly named models.
|
# defaults from unrelated providers with similarly named models.
|
||||||
|
# But first try querying the local server directly.
|
||||||
|
if is_local_endpoint(base_url):
|
||||||
|
local_ctx = _query_local_context_length(model, base_url)
|
||||||
|
if local_ctx and local_ctx > 0:
|
||||||
|
save_context_length(model, base_url, local_ctx)
|
||||||
|
return local_ctx
|
||||||
logger.info(
|
logger.info(
|
||||||
"Could not detect context length for model %r at %s — "
|
"Could not detect context length for model %r at %s — "
|
||||||
"defaulting to %s tokens (probe-down). Set model.context_length "
|
"defaulting to %s tokens (probe-down). Set model.context_length "
|
||||||
|
|
@ -527,7 +746,14 @@ def get_model_context_length(
|
||||||
if default_model in model or model in default_model:
|
if default_model in model or model in default_model:
|
||||||
return length
|
return length
|
||||||
|
|
||||||
# 5. Unknown model — start at highest probe tier
|
# 5. Query local server for unknown models before defaulting to 2M
|
||||||
|
if base_url and is_local_endpoint(base_url):
|
||||||
|
local_ctx = _query_local_context_length(model, base_url)
|
||||||
|
if local_ctx and local_ctx > 0:
|
||||||
|
save_context_length(model, base_url, local_ctx)
|
||||||
|
return local_ctx
|
||||||
|
|
||||||
|
# 6. Unknown model — start at highest probe tier
|
||||||
return CONTEXT_PROBE_TIERS[0]
|
return CONTEXT_PROBE_TIERS[0]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
16
run_agent.py
16
run_agent.py
|
|
@ -6569,7 +6569,21 @@ class AIAgent:
|
||||||
self._response_was_previewed = True
|
self._response_was_previewed = True
|
||||||
break
|
break
|
||||||
|
|
||||||
# No fallback -- append the empty message as-is
|
# No fallback -- if reasoning_text exists, the model put its
|
||||||
|
# entire response inside <think> tags; use that as the content.
|
||||||
|
if reasoning_text:
|
||||||
|
self._vprint(f"{self.log_prefix}Using reasoning as response content (model wrapped entire response in think tags).", force=True)
|
||||||
|
final_response = reasoning_text
|
||||||
|
empty_msg = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": final_response,
|
||||||
|
"reasoning": reasoning_text,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
}
|
||||||
|
messages.append(empty_msg)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Truly empty -- no reasoning and no content
|
||||||
empty_msg = {
|
empty_msg = {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": final_response,
|
"content": final_response,
|
||||||
|
|
|
||||||
493
tests/test_model_metadata_local_ctx.py
Normal file
493
tests/test_model_metadata_local_ctx.py
Normal file
|
|
@ -0,0 +1,493 @@
|
||||||
|
"""Tests for _query_local_context_length and the local server fallback in
|
||||||
|
get_model_context_length.
|
||||||
|
|
||||||
|
All tests use synthetic inputs — no filesystem or live server required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _query_local_context_length — unit tests with mocked httpx
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestQueryLocalContextLengthOllama:
    """Ollama branch of _query_local_context_length (server_type == 'ollama')."""

    def _make_resp(self, status_code, body):
        """Build a fake HTTP response with the given status and JSON body."""
        fake = MagicMock()
        fake.status_code = status_code
        fake.json.return_value = body
        return fake

    def _query(self, model, post_resp, get_resp):
        """Run the lookup against a mocked Ollama server and return the result.

        Every POST yields *post_resp* and every GET yields *get_resp*.
        """
        from agent.model_metadata import _query_local_context_length

        client = MagicMock()
        client.__enter__ = lambda _self: client
        client.__exit__ = MagicMock(return_value=False)
        client.post.return_value = post_resp
        client.get.return_value = get_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
                patch("httpx.Client", return_value=client):
            return _query_local_context_length(model, "http://localhost:11434/v1")

    def test_ollama_model_info_context_length(self):
        """The context length comes from the model_info dict of /api/show."""
        show = self._make_resp(200, {
            "model_info": {"llama.context_length": 131072}
        })
        not_found = self._make_resp(404, {})

        assert self._query("omnicoder-9b", show, not_found) == 131072

    def test_ollama_parameters_num_ctx(self):
        """num_ctx in the parameters string is used when model_info has no context_length."""
        show = self._make_resp(200, {
            "model_info": {},
            "parameters": "num_ctx 32768\ntemperature 0.7\n"
        })
        not_found = self._make_resp(404, {})

        assert self._query("some-model", show, not_found) == 32768

    def test_ollama_show_404_falls_through(self):
        """A 404 from /api/show falls through to /v1/models/{model}."""
        show_missing = self._make_resp(404, {})
        detail = self._make_resp(200, {"max_model_len": 65536})

        assert self._query("some-model", show_missing, detail) == 65536
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryLocalContextLengthVllm:
    """vLLM-style /v1/models/{model} responses for _query_local_context_length."""

    def _make_resp(self, status_code, body):
        """Build a fake HTTP response with the given status and JSON body."""
        fake = MagicMock()
        fake.status_code = status_code
        fake.json.return_value = body
        return fake

    def _query(self, model, detail_resp):
        """Run the lookup against a mocked vLLM server answering GETs with *detail_resp*."""
        from agent.model_metadata import _query_local_context_length

        client = MagicMock()
        client.__enter__ = lambda _self: client
        client.__exit__ = MagicMock(return_value=False)
        client.post.return_value = self._make_resp(404, {})
        client.get.return_value = detail_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \
                patch("httpx.Client", return_value=client):
            return _query_local_context_length(model, "http://localhost:8000/v1")

    def test_vllm_max_model_len(self):
        """max_model_len is read from the /v1/models/{model} response."""
        detail = self._make_resp(200, {"id": "omnicoder-9b", "max_model_len": 100000})

        assert self._query("omnicoder-9b", detail) == 100000

    def test_vllm_context_length_key(self):
        """context_length is read from the /v1/models/{model} response."""
        detail = self._make_resp(200, {"id": "some-model", "context_length": 32768})

        assert self._query("some-model", detail) == 32768
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryLocalContextLengthModelsList:
    """Fallback of _query_local_context_length to the /v1/models list."""

    def _make_resp(self, status_code, body):
        """Build a fake HTTP response with the given status and JSON body."""
        fake = MagicMock()
        fake.status_code = status_code
        fake.json.return_value = body
        return fake

    def _query(self, model, detail_resp, list_resp):
        """Run the lookup with an undetected server type.

        The first GET (/v1/models/{model}) yields *detail_resp*; every later
        GET (/v1/models) yields *list_resp*.
        """
        from agent.model_metadata import _query_local_context_length

        first_only = [detail_resp]

        def fake_get(url, **kwargs):
            return first_only.pop() if first_only else list_resp

        client = MagicMock()
        client.__enter__ = lambda _self: client
        client.__exit__ = MagicMock(return_value=False)
        client.post.return_value = self._make_resp(404, {})
        client.get.side_effect = fake_get

        with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
                patch("httpx.Client", return_value=client):
            return _query_local_context_length(model, "http://localhost:1234")

    def test_models_list_max_model_len(self):
        """The model's context length is found in the /v1/models list."""
        listing = self._make_resp(200, {
            "data": [
                {"id": "other-model", "max_model_len": 4096},
                {"id": "omnicoder-9b", "max_model_len": 131072},
            ]
        })

        result = self._query("omnicoder-9b", self._make_resp(404, {}), listing)

        assert result == 131072

    def test_models_list_model_not_found_returns_none(self):
        """None is returned when the model is absent from the /v1/models list."""
        listing = self._make_resp(200, {
            "data": [{"id": "other-model", "max_model_len": 4096}]
        })

        result = self._query("omnicoder-9b", self._make_resp(404, {}), listing)

        assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryLocalContextLengthLmStudio:
    """LM Studio native /api/v1/models responses for _query_local_context_length."""

    def _make_resp(self, status_code, body):
        """Build a fake HTTP response with the given status and JSON body."""
        fake = MagicMock()
        fake.status_code = status_code
        fake.json.return_value = body
        return fake

    def _make_client(self, native_resp, detail_resp, list_resp):
        """Build a mock httpx.Client whose GETs return the three responses in order.

        Any GET beyond the third receives a fresh 404 response.
        """
        client = MagicMock()
        client.__enter__ = lambda _self: client
        client.__exit__ = MagicMock(return_value=False)
        client.post.return_value = self._make_resp(404, {})

        queued = iter([native_resp, detail_resp, list_resp])
        exhausted = object()

        def fake_get(url, **kwargs):
            nxt = next(queued, exhausted)
            if nxt is exhausted:
                return self._make_resp(404, {})
            return nxt

        client.get.side_effect = fake_get
        return client

    def _query(self, model, client):
        """Run the lookup for *model* against *client* posing as LM Studio."""
        from agent.model_metadata import _query_local_context_length

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
                patch("httpx.Client", return_value=client):
            return _query_local_context_length(model, "http://192.168.1.22:1234/v1")

    def _native_only_client(self, native_body):
        """Client whose native endpoint returns *native_body* and the rest 404."""
        return self._make_client(
            self._make_resp(200, native_body),
            self._make_resp(404, {}),
            self._make_resp(404, {}),
        )

    def test_lmstudio_exact_key_match(self):
        """max_context_length is read when the key matches exactly."""
        client = self._native_only_client({
            "models": [
                {"key": "nvidia/nvidia-nemotron-super-49b-v1",
                 "id": "nvidia/nvidia-nemotron-super-49b-v1",
                 "max_context_length": 131072},
            ]
        })

        result = self._query("nvidia/nvidia-nemotron-super-49b-v1", client)

        assert result == 131072

    def test_lmstudio_slug_only_matches_key_with_publisher_prefix(self):
        """A bare slug matches a key that carries a publisher prefix.

        The user configures "local:nvidia-nemotron-super-49b-v1" (slug only),
        while LM Studio's native API lists the model as
        "nvidia/nvidia-nemotron-super-49b-v1"; the lookup must still succeed.
        """
        client = self._native_only_client({
            "models": [
                {"key": "nvidia/nvidia-nemotron-super-49b-v1",
                 "id": "nvidia/nvidia-nemotron-super-49b-v1",
                 "max_context_length": 131072},
            ]
        })

        # The model passed in is just the slug left after stripping "local:".
        result = self._query("nvidia-nemotron-super-49b-v1", client)

        assert result == 131072

    def test_lmstudio_v1_models_list_slug_fuzzy_match(self):
        """The slug match also applies to the /v1/models list.

        LM Studio's OpenAI-compatible /v1/models returns an id such as
        "nvidia/nvidia-nemotron-super-49b-v1", which must match the bare slug.
        """
        client = self._make_client(
            # native /api/v1/models: unavailable
            self._make_resp(404, {}),
            # /v1/models/{model}: unavailable
            self._make_resp(404, {}),
            # /v1/models list: model present with publisher prefix
            self._make_resp(200, {
                "data": [
                    {"id": "nvidia/nvidia-nemotron-super-49b-v1", "context_length": 131072},
                ]
            }),
        )

        result = self._query("nvidia-nemotron-super-49b-v1", client)

        assert result == 131072

    def test_lmstudio_loaded_instances_context_length(self):
        """loaded_instances supplies context_length when max_context_length is absent."""
        client = self._native_only_client({
            "models": [
                {
                    "key": "nvidia/nvidia-nemotron-super-49b-v1",
                    "id": "nvidia/nvidia-nemotron-super-49b-v1",
                    "loaded_instances": [
                        {"config": {"context_length": 65536}},
                    ],
                },
            ]
        })

        result = self._query("nvidia-nemotron-super-49b-v1", client)

        assert result == 65536

    def test_lmstudio_loaded_instance_beats_max_context_length(self):
        """The loaded instance's context_length outranks max_context_length.

        LM Studio may report max_context_length=1_048_576 (theoretical model
        maximum) while the loaded context is 122_651 (runtime setting).  The
        loaded value is the real constraint and must be preferred.
        """
        client = self._native_only_client({
            "models": [
                {
                    "key": "nvidia/nvidia-nemotron-3-nano-4b",
                    "id": "nvidia/nvidia-nemotron-3-nano-4b",
                    "max_context_length": 1_048_576,
                    "loaded_instances": [
                        {"config": {"context_length": 122_651}},
                    ],
                },
            ]
        })

        result = self._query("nvidia-nemotron-3-nano-4b", client)

        assert result == 122_651, (
            f"Expected loaded instance context (122651) but got {result}. "
            "max_context_length (1048576) must not win over loaded_instances."
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TestQueryLocalContextLengthNetworkError:
    """Network failures are absorbed by _query_local_context_length."""

    def test_connection_error_returns_none(self):
        """An unreachable server yields None."""
        from agent.model_metadata import _query_local_context_length

        client_mock = MagicMock()
        # ``with httpx.Client(...) as c`` must hand back this same mock, and
        # __exit__ must return a falsy value so exceptions are not suppressed
        # by the context manager itself.
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        # Every request verb the function might use fails the same way.
        for request_method in (client_mock.post, client_mock.get):
            request_method.side_effect = Exception("Connection refused")

        server_type_patch = patch(
            "agent.model_metadata.detect_local_server_type", return_value=None
        )
        http_client_patch = patch("httpx.Client", return_value=client_mock)

        with server_type_patch, http_client_patch:
            result = _query_local_context_length(
                "omnicoder-9b", "http://localhost:11434/v1"
            )

        assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_model_context_length — integration-style tests with mocked helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGetModelContextLengthLocalFallback:
    """get_model_context_length uses local server query before falling back to 2M.

    Every test stubs the collaborators inside ``agent.model_metadata`` so
    that only the decision logic of ``get_model_context_length`` is
    exercised: an empty cache and empty metadata force the lookup to reach
    the local-server step.
    """

    def test_local_endpoint_unknown_model_queries_server(self):
        """Unknown model on local endpoint gets ctx from server, not 2M default."""
        from agent.model_metadata import get_model_context_length

        # save_context_length is stubbed without being inspected here; the
        # dedicated caching test below asserts on it.
        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=True), \
             patch("agent.model_metadata._query_local_context_length", return_value=131072), \
             patch("agent.model_metadata.save_context_length"):
            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == 131072

    def test_local_endpoint_unknown_model_result_is_cached(self):
        """Context length returned from local server is persisted to cache."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=True), \
             patch("agent.model_metadata._query_local_context_length", return_value=131072), \
             patch("agent.model_metadata.save_context_length") as mock_save:
            get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        mock_save.assert_called_once_with("omnicoder-9b", "http://localhost:11434/v1", 131072)

    def test_local_endpoint_server_returns_none_falls_back_to_2m(self):
        """When local server returns None, still falls back to 2M probe tier."""
        from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=True), \
             patch("agent.model_metadata._query_local_context_length", return_value=None):
            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == CONTEXT_PROBE_TIERS[0]

    def test_non_local_endpoint_does_not_query_local_server(self):
        """For non-local endpoints, _query_local_context_length is not called."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=False), \
             patch("agent.model_metadata._query_local_context_length") as mock_query:
            # Only the absence of the local query matters; the returned
            # value is intentionally not inspected.
            get_model_context_length(
                "unknown-model", "https://some-cloud-api.example.com/v1"
            )

        mock_query.assert_not_called()

    def test_cached_result_skips_local_query(self):
        """Cached context length is returned without querying the local server."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=65536), \
             patch("agent.model_metadata._query_local_context_length") as mock_query:
            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == 65536
        mock_query.assert_not_called()

    def test_no_base_url_does_not_query_local_server(self):
        """When base_url is empty, local server is not queried."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata._query_local_context_length") as mock_query:
            # Only the absence of the local query matters; the returned
            # value is intentionally not inspected.
            get_model_context_length("unknown-xyz-model", "")

        mock_query.assert_not_called()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue