fix(providers): pass extra headers to model discovery

This commit is contained in:
helix4u 2026-07-02 14:33:41 -06:00 committed by kshitij
parent 80a774f972
commit ab40e952f3
5 changed files with 211 additions and 18 deletions

View file

@ -23,7 +23,7 @@ from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from typing import List, NamedTuple, Optional
from typing import Any, List, NamedTuple, Optional
from hermes_cli.providers import (
ProviderDef,
@ -1362,6 +1362,19 @@ import threading as _threading # noqa: E402
_picker_prewarm_done = _threading.Event()
def _extra_headers_from_config(entry: Any) -> dict[str, str]:
if not isinstance(entry, dict):
return {}
headers = entry.get("extra_headers")
if not isinstance(headers, dict) or not headers:
return {}
return {
str(key): str(value)
for key, value in headers.items()
if value is not None
}
def prewarm_picker_cache_async() -> Optional["_threading.Thread"]:
"""Warm the provider-models disk cache in a background daemon thread.
@ -1993,7 +2006,11 @@ def list_authenticated_providers(
if should_probe:
try:
from hermes_cli.models import fetch_api_models
live_models = fetch_api_models(api_key, api_url)
live_models = fetch_api_models(
api_key,
api_url,
headers=_extra_headers_from_config(ep_cfg) or None,
)
if live_models:
models_list = live_models
except Exception:
@ -2130,10 +2147,13 @@ def list_authenticated_providers(
"api_key": api_key,
"models": [],
"discover_models": discover,
"extra_headers": _extra_headers_from_config(entry),
}
else:
if api_key and not groups[group_key].get("api_key"):
groups[group_key]["api_key"] = api_key
if not groups[group_key].get("extra_headers"):
groups[group_key]["extra_headers"] = _extra_headers_from_config(entry)
# If any entry in this group opts out of discovery,
# honour that for the whole grouped row.
if not discover:
@ -2240,7 +2260,11 @@ def list_authenticated_providers(
try:
from hermes_cli.models import fetch_api_models
live_models = fetch_api_models(api_key, api_url)
live_models = fetch_api_models(
api_key,
api_url,
headers=grp.get("extra_headers") or None,
)
if live_models:
grp["models"] = live_models
grp["total_models"] = len(live_models)

View file

@ -3451,6 +3451,7 @@ def probe_api_models(
base_url: Optional[str],
timeout: float = 5.0,
api_mode: Optional[str] = None,
request_headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
"""Probe a ``/models`` endpoint with light URL heuristics.
@ -3497,6 +3498,16 @@ def probe_api_models(
headers["Authorization"] = f"Bearer {api_key}"
if normalized.startswith(COPILOT_BASE_URL):
headers.update(copilot_default_headers())
if isinstance(request_headers, dict):
# Per-provider custom headers can contain auth/proxy secrets. Merge
# last so endpoint-specific config wins, and never log the values.
headers.update(
{
str(key): str(value)
for key, value in request_headers.items()
if value is not None
}
)
for candidate_base, is_fallback in candidates:
url = candidate_base.rstrip("/") + "/models"
@ -3529,13 +3540,20 @@ def fetch_api_models(
base_url: Optional[str],
timeout: float = 5.0,
api_mode: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
) -> Optional[list[str]]:
"""Fetch the list of available model IDs from the provider's ``/models`` endpoint.
Returns a list of model ID strings, or ``None`` if the endpoint could not
be reached (network error, timeout, auth failure, etc.).
"""
return probe_api_models(api_key, base_url, timeout=timeout, api_mode=api_mode).get("models")
return probe_api_models(
api_key,
base_url,
timeout=timeout,
api_mode=api_mode,
request_headers=headers,
).get("models")
# ---------------------------------------------------------------------------

View file

@ -4,11 +4,14 @@ PR #3526 salvage — user-configurable extra HTTP headers on LLM API calls
(reverse proxies, gateways, custom auth such as Cloudflare Access tokens).
"""
import json
from hermes_cli.config import (
_normalize_custom_provider_entry,
apply_custom_provider_extra_headers_to_client_kwargs,
get_custom_provider_extra_headers,
)
from hermes_cli import models as models_mod
def test_normalize_entry_keeps_extra_headers():
@ -125,3 +128,43 @@ def test_apply_extra_headers_noop_without_match():
custom_providers=providers,
)
assert "default_headers" not in client_kwargs
def test_fetch_api_models_sends_extra_headers_to_models_probe(monkeypatch):
captured = {}
class FakeResponse:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def read(self):
return json.dumps({"data": [{"id": "proxy-model"}]}).encode()
def fake_urlopen(request, timeout=0):
captured["url"] = request.full_url
captured["timeout"] = timeout
captured["headers"] = {
key.lower(): value
for key, value in request.header_items()
}
return FakeResponse()
monkeypatch.setattr(models_mod.urllib.request, "urlopen", fake_urlopen)
models = models_mod.fetch_api_models(
"proxy-key",
"https://llm.internal.example.com/v1",
headers={
"sleeve-harness": "hermes",
"sleeve-base-url": "http://localhost:8081/v1",
},
)
assert models == ["proxy-model"]
assert captured["url"] == "https://llm.internal.example.com/v1/models"
assert captured["headers"]["authorization"] == "Bearer proxy-key"
assert captured["headers"]["sleeve-harness"] == "hermes"
assert captured["headers"]["sleeve-base-url"] == "http://localhost:8081/v1"

View file

@ -643,8 +643,8 @@ def test_custom_providers_uses_live_models_for_multi_model_endpoint(monkeypatch)
calls = []
def fake_fetch_api_models(api_key, base_url):
calls.append((api_key, base_url))
def fake_fetch_api_models(api_key, base_url, **kwargs):
calls.append((api_key, base_url, kwargs))
return ["gateway-model-a", "gateway-model-b", "gateway-model-c"]
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)
@ -679,9 +679,9 @@ def test_custom_providers_uses_live_models_for_multi_model_endpoint(monkeypatch)
)
assert gateway_prov is not None, "Custom provider group not found in results"
assert calls == [("sk-gateway-key", "https://gateway.example.com/v1")], (
"fetch_api_models must be called with the custom provider's credentials"
)
assert calls == [
("sk-gateway-key", "https://gateway.example.com/v1", {"headers": None})
], "fetch_api_models must be called with the custom provider's credentials"
assert gateway_prov["models"] == [
"gateway-model-a",
"gateway-model-b",
@ -690,6 +690,61 @@ def test_custom_providers_uses_live_models_for_multi_model_endpoint(monkeypatch)
assert gateway_prov["total_models"] == 3
def test_custom_provider_live_model_probe_uses_extra_headers(monkeypatch):
"""custom_providers[].extra_headers must apply to live /models probes."""
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
calls = []
def fake_fetch_api_models(api_key, base_url, **kwargs):
calls.append((api_key, base_url, kwargs))
return ["gateway-model"]
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)
providers = list_authenticated_providers(
current_provider="openrouter",
current_base_url="https://openrouter.ai/api/v1",
custom_providers=[
{
"name": "LLM Proxy",
"api_key": "local-key",
"base_url": "http://localhost:8081/v1",
"extra_headers": {
"sleeve-harness": "hermes",
"sleeve-base-url": "http://localhost:8081/v1",
},
}
],
max_models=50,
)
gateway_prov = next(
(
p
for p in providers
if p.get("api_url") == "http://localhost:8081/v1"
),
None,
)
assert gateway_prov is not None
assert calls == [
(
"local-key",
"http://localhost:8081/v1",
{
"headers": {
"sleeve-harness": "hermes",
"sleeve-base-url": "http://localhost:8081/v1",
}
},
)
]
assert gateway_prov["models"] == ["gateway-model"]
def test_custom_providers_discover_models_false_keeps_explicit_subset(monkeypatch):
"""Custom providers (section 4) with ``discover_models: false`` must keep
their explicit ``models:`` subset instead of replacing it with live
@ -704,8 +759,8 @@ def test_custom_providers_discover_models_false_keeps_explicit_subset(monkeypatc
calls = []
def fake_fetch_api_models(api_key, base_url):
calls.append((api_key, base_url))
def fake_fetch_api_models(api_key, base_url, **kwargs):
calls.append((api_key, base_url, kwargs))
return ["gateway-model-a", "gateway-model-b", "gateway-model-c"]
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)
@ -760,8 +815,8 @@ def test_custom_providers_discover_models_false_string_is_normalised(monkeypatch
calls = []
def fake_fetch_api_models(api_key, base_url):
calls.append((api_key, base_url))
def fake_fetch_api_models(api_key, base_url, **kwargs):
calls.append((api_key, base_url, kwargs))
return ["live-a", "live-b"]
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)

View file

@ -144,8 +144,8 @@ def test_list_authenticated_providers_uses_live_models_for_user_provider(monkeyp
calls = []
def fake_fetch_api_models(api_key, base_url):
calls.append((api_key, base_url))
def fake_fetch_api_models(api_key, base_url, **kwargs):
calls.append((api_key, base_url, kwargs))
return ["old-configured-model", "new-live-model"]
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)
@ -175,11 +175,62 @@ def test_list_authenticated_providers_uses_live_models_for_user_provider(monkeyp
)
assert user_prov is not None
assert calls == [("sk-test", "http://127.0.0.1:3000/api/v1")]
assert calls == [("sk-test", "http://127.0.0.1:3000/api/v1", {"headers": None})]
assert user_prov["models"] == ["old-configured-model", "new-live-model"]
assert user_prov["total_models"] == 2
def test_user_provider_live_model_probe_uses_extra_headers(monkeypatch):
"""providers.<name>.extra_headers must also apply to live /models probes."""
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
calls = []
def fake_fetch_api_models(api_key, base_url, **kwargs):
calls.append((api_key, base_url, kwargs))
return ["live-model"]
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)
providers = list_authenticated_providers(
current_provider="llm-proxy",
user_providers={
"llm-proxy": {
"name": "LLM Proxy",
"base_url": "http://localhost:8081/v1",
"api_key": "local-key",
"extra_headers": {
"sleeve-harness": "hermes",
"sleeve-base-url": "http://localhost:8081/v1",
},
}
},
custom_providers=[],
max_models=50,
)
user_prov = next(
(p for p in providers if p.get("is_user_defined") and p["slug"] == "llm-proxy"),
None,
)
assert user_prov is not None
assert calls == [
(
"local-key",
"http://localhost:8081/v1",
{
"headers": {
"sleeve-harness": "hermes",
"sleeve-base-url": "http://localhost:8081/v1",
}
},
)
]
assert user_prov["models"] == ["live-model"]
def test_list_authenticated_providers_dict_models_without_default_model(monkeypatch):
"""Dict-format ``models:`` without a ``default_model`` must still expose
every dict key, not collapse to an empty list."""
@ -1063,10 +1114,11 @@ def test_section3_probes_no_key_endpoint_without_explicit_models(monkeypatch):
probed = {}
def _fake_fetch(api_key, api_url):
def _fake_fetch(api_key, api_url, **kwargs):
probed["called"] = True
probed["api_key"] = api_key
probed["api_url"] = api_url
probed["kwargs"] = kwargs
return ["live-model-1", "live-model-2", "live-model-3"]
monkeypatch.setattr("hermes_cli.models.fetch_api_models", _fake_fetch)
@ -1088,6 +1140,7 @@ def test_section3_probes_no_key_endpoint_without_explicit_models(monkeypatch):
assert probed.get("called") is True, "no-key bare endpoint should be probed"
assert probed["api_key"] == ""
assert probed["kwargs"] == {"headers": None}
row = next(p for p in providers if p["slug"] == "local-llamacpp")
assert row["models"] == ["live-model-1", "live-model-2", "live-model-3"]
assert row["total_models"] == 3
@ -1099,7 +1152,7 @@ def test_section3_skips_probe_when_no_key_but_explicit_models(monkeypatch):
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
def _fail_fetch(api_key, api_url):
def _fail_fetch(api_key, api_url, **kwargs):
raise AssertionError("should not probe when explicit models are set")
monkeypatch.setattr("hermes_cli.models.fetch_api_models", _fail_fetch)