fix(providers): pass extra headers to model discovery
This commit is contained in:
parent
80a774f972
commit
ab40e952f3
5 changed files with 211 additions and 18 deletions
|
|
@ -23,7 +23,7 @@ from __future__ import annotations
|
|||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
|
||||
from hermes_cli.providers import (
|
||||
ProviderDef,
|
||||
|
|
@ -1362,6 +1362,19 @@ import threading as _threading # noqa: E402
|
|||
_picker_prewarm_done = _threading.Event()
|
||||
|
||||
|
||||
def _extra_headers_from_config(entry: Any) -> dict[str, str]:
|
||||
if not isinstance(entry, dict):
|
||||
return {}
|
||||
headers = entry.get("extra_headers")
|
||||
if not isinstance(headers, dict) or not headers:
|
||||
return {}
|
||||
return {
|
||||
str(key): str(value)
|
||||
for key, value in headers.items()
|
||||
if value is not None
|
||||
}
|
||||
|
||||
|
||||
def prewarm_picker_cache_async() -> Optional["_threading.Thread"]:
|
||||
"""Warm the provider-models disk cache in a background daemon thread.
|
||||
|
||||
|
|
@ -1993,7 +2006,11 @@ def list_authenticated_providers(
|
|||
if should_probe:
|
||||
try:
|
||||
from hermes_cli.models import fetch_api_models
|
||||
live_models = fetch_api_models(api_key, api_url)
|
||||
live_models = fetch_api_models(
|
||||
api_key,
|
||||
api_url,
|
||||
headers=_extra_headers_from_config(ep_cfg) or None,
|
||||
)
|
||||
if live_models:
|
||||
models_list = live_models
|
||||
except Exception:
|
||||
|
|
@ -2130,10 +2147,13 @@ def list_authenticated_providers(
|
|||
"api_key": api_key,
|
||||
"models": [],
|
||||
"discover_models": discover,
|
||||
"extra_headers": _extra_headers_from_config(entry),
|
||||
}
|
||||
else:
|
||||
if api_key and not groups[group_key].get("api_key"):
|
||||
groups[group_key]["api_key"] = api_key
|
||||
if not groups[group_key].get("extra_headers"):
|
||||
groups[group_key]["extra_headers"] = _extra_headers_from_config(entry)
|
||||
# If any entry in this group opts out of discovery,
|
||||
# honour that for the whole grouped row.
|
||||
if not discover:
|
||||
|
|
@ -2240,7 +2260,11 @@ def list_authenticated_providers(
|
|||
try:
|
||||
from hermes_cli.models import fetch_api_models
|
||||
|
||||
live_models = fetch_api_models(api_key, api_url)
|
||||
live_models = fetch_api_models(
|
||||
api_key,
|
||||
api_url,
|
||||
headers=grp.get("extra_headers") or None,
|
||||
)
|
||||
if live_models:
|
||||
grp["models"] = live_models
|
||||
grp["total_models"] = len(live_models)
|
||||
|
|
|
|||
|
|
@ -3451,6 +3451,7 @@ def probe_api_models(
|
|||
base_url: Optional[str],
|
||||
timeout: float = 5.0,
|
||||
api_mode: Optional[str] = None,
|
||||
request_headers: Optional[dict[str, str]] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Probe a ``/models`` endpoint with light URL heuristics.
|
||||
|
||||
|
|
@ -3497,6 +3498,16 @@ def probe_api_models(
|
|||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
if normalized.startswith(COPILOT_BASE_URL):
|
||||
headers.update(copilot_default_headers())
|
||||
if isinstance(request_headers, dict):
|
||||
# Per-provider custom headers can contain auth/proxy secrets. Merge
|
||||
# last so endpoint-specific config wins, and never log the values.
|
||||
headers.update(
|
||||
{
|
||||
str(key): str(value)
|
||||
for key, value in request_headers.items()
|
||||
if value is not None
|
||||
}
|
||||
)
|
||||
|
||||
for candidate_base, is_fallback in candidates:
|
||||
url = candidate_base.rstrip("/") + "/models"
|
||||
|
|
@ -3529,13 +3540,20 @@ def fetch_api_models(
|
|||
base_url: Optional[str],
|
||||
timeout: float = 5.0,
|
||||
api_mode: Optional[str] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
) -> Optional[list[str]]:
|
||||
"""Fetch the list of available model IDs from the provider's ``/models`` endpoint.
|
||||
|
||||
Returns a list of model ID strings, or ``None`` if the endpoint could not
|
||||
be reached (network error, timeout, auth failure, etc.).
|
||||
"""
|
||||
return probe_api_models(api_key, base_url, timeout=timeout, api_mode=api_mode).get("models")
|
||||
return probe_api_models(
|
||||
api_key,
|
||||
base_url,
|
||||
timeout=timeout,
|
||||
api_mode=api_mode,
|
||||
request_headers=headers,
|
||||
).get("models")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -4,11 +4,14 @@ PR #3526 salvage — user-configurable extra HTTP headers on LLM API calls
|
|||
(reverse proxies, gateways, custom auth such as Cloudflare Access tokens).
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from hermes_cli.config import (
|
||||
_normalize_custom_provider_entry,
|
||||
apply_custom_provider_extra_headers_to_client_kwargs,
|
||||
get_custom_provider_extra_headers,
|
||||
)
|
||||
from hermes_cli import models as models_mod
|
||||
|
||||
|
||||
def test_normalize_entry_keeps_extra_headers():
|
||||
|
|
@ -125,3 +128,43 @@ def test_apply_extra_headers_noop_without_match():
|
|||
custom_providers=providers,
|
||||
)
|
||||
assert "default_headers" not in client_kwargs
|
||||
|
||||
|
||||
def test_fetch_api_models_sends_extra_headers_to_models_probe(monkeypatch):
    """``fetch_api_models(headers=...)`` must put those headers on the ``/models`` request.

    ``urlopen`` is replaced with a fake that records the request URL, the
    timeout, and the request headers (names lower-cased), and answers with a
    single-model JSON payload.
    """
    seen = {}
    payload = json.dumps({"data": [{"id": "proxy-model"}]}).encode()

    class FakeResponse:
        """Context-manager stand-in for the object returned by ``urlopen``."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            # Returning False lets any exception raised in the ``with`` body propagate.
            return False

        def read(self):
            return payload

    def fake_urlopen(request, timeout=0):
        seen["url"] = request.full_url
        seen["timeout"] = timeout
        # Lower-case the names so the assertions below do not depend on the
        # capitalisation reported by ``header_items()``.
        lowered = {}
        for name, value in request.header_items():
            lowered[name.lower()] = value
        seen["headers"] = lowered
        return FakeResponse()

    monkeypatch.setattr(models_mod.urllib.request, "urlopen", fake_urlopen)

    extra_headers = {
        "sleeve-harness": "hermes",
        "sleeve-base-url": "http://localhost:8081/v1",
    }
    models = models_mod.fetch_api_models(
        "proxy-key",
        "https://llm.internal.example.com/v1",
        headers=extra_headers,
    )

    assert models == ["proxy-model"]
    assert seen["url"] == "https://llm.internal.example.com/v1/models"

    sent = seen["headers"]
    # The bearer token must still be sent alongside the custom headers.
    assert sent["authorization"] == "Bearer proxy-key"
    assert sent["sleeve-harness"] == "hermes"
    assert sent["sleeve-base-url"] == "http://localhost:8081/v1"
|
||||
|
|
|
|||
|
|
@ -643,8 +643,8 @@ def test_custom_providers_uses_live_models_for_multi_model_endpoint(monkeypatch)
|
|||
|
||||
calls = []
|
||||
|
||||
def fake_fetch_api_models(api_key, base_url):
|
||||
calls.append((api_key, base_url))
|
||||
def fake_fetch_api_models(api_key, base_url, **kwargs):
|
||||
calls.append((api_key, base_url, kwargs))
|
||||
return ["gateway-model-a", "gateway-model-b", "gateway-model-c"]
|
||||
|
||||
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)
|
||||
|
|
@ -679,9 +679,9 @@ def test_custom_providers_uses_live_models_for_multi_model_endpoint(monkeypatch)
|
|||
)
|
||||
|
||||
assert gateway_prov is not None, "Custom provider group not found in results"
|
||||
assert calls == [("sk-gateway-key", "https://gateway.example.com/v1")], (
|
||||
"fetch_api_models must be called with the custom provider's credentials"
|
||||
)
|
||||
assert calls == [
|
||||
("sk-gateway-key", "https://gateway.example.com/v1", {"headers": None})
|
||||
], "fetch_api_models must be called with the custom provider's credentials"
|
||||
assert gateway_prov["models"] == [
|
||||
"gateway-model-a",
|
||||
"gateway-model-b",
|
||||
|
|
@ -690,6 +690,61 @@ def test_custom_providers_uses_live_models_for_multi_model_endpoint(monkeypatch)
|
|||
assert gateway_prov["total_models"] == 3
|
||||
|
||||
|
||||
def test_custom_provider_live_model_probe_uses_extra_headers(monkeypatch):
    """custom_providers[].extra_headers must apply to live /models probes."""
    # Replace the models.dev fetcher and the overlay table with empty stand-ins.
    monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
    monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})

    proxy_url = "http://localhost:8081/v1"
    proxy_headers = {
        "sleeve-harness": "hermes",
        "sleeve-base-url": "http://localhost:8081/v1",
    }
    recorded = []

    def fake_fetch_api_models(api_key, base_url, **kwargs):
        # Capture the keyword arguments too, so the ``headers`` value can be checked.
        recorded.append((api_key, base_url, kwargs))
        return ["gateway-model"]

    monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)

    providers = list_authenticated_providers(
        current_provider="openrouter",
        current_base_url="https://openrouter.ai/api/v1",
        custom_providers=[
            {
                "name": "LLM Proxy",
                "api_key": "local-key",
                "base_url": proxy_url,
                # Pass a copy so the expectation below is an independent object.
                "extra_headers": dict(proxy_headers),
            }
        ],
        max_models=50,
    )

    gateway_prov = None
    for candidate in providers:
        if candidate.get("api_url") == proxy_url:
            gateway_prov = candidate
            break

    assert gateway_prov is not None
    # Exactly one probe, carrying the provider's key, URL, and configured headers.
    assert recorded == [("local-key", proxy_url, {"headers": proxy_headers})]
    assert gateway_prov["models"] == ["gateway-model"]
|
||||
|
||||
|
||||
def test_custom_providers_discover_models_false_keeps_explicit_subset(monkeypatch):
|
||||
"""Custom providers (section 4) with ``discover_models: false`` must keep
|
||||
their explicit ``models:`` subset instead of replacing it with live
|
||||
|
|
@ -704,8 +759,8 @@ def test_custom_providers_discover_models_false_keeps_explicit_subset(monkeypatc
|
|||
|
||||
calls = []
|
||||
|
||||
def fake_fetch_api_models(api_key, base_url):
|
||||
calls.append((api_key, base_url))
|
||||
def fake_fetch_api_models(api_key, base_url, **kwargs):
|
||||
calls.append((api_key, base_url, kwargs))
|
||||
return ["gateway-model-a", "gateway-model-b", "gateway-model-c"]
|
||||
|
||||
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)
|
||||
|
|
@ -760,8 +815,8 @@ def test_custom_providers_discover_models_false_string_is_normalised(monkeypatch
|
|||
|
||||
calls = []
|
||||
|
||||
def fake_fetch_api_models(api_key, base_url):
|
||||
calls.append((api_key, base_url))
|
||||
def fake_fetch_api_models(api_key, base_url, **kwargs):
|
||||
calls.append((api_key, base_url, kwargs))
|
||||
return ["live-a", "live-b"]
|
||||
|
||||
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)
|
||||
|
|
|
|||
|
|
@ -144,8 +144,8 @@ def test_list_authenticated_providers_uses_live_models_for_user_provider(monkeyp
|
|||
|
||||
calls = []
|
||||
|
||||
def fake_fetch_api_models(api_key, base_url):
|
||||
calls.append((api_key, base_url))
|
||||
def fake_fetch_api_models(api_key, base_url, **kwargs):
|
||||
calls.append((api_key, base_url, kwargs))
|
||||
return ["old-configured-model", "new-live-model"]
|
||||
|
||||
monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)
|
||||
|
|
@ -175,11 +175,62 @@ def test_list_authenticated_providers_uses_live_models_for_user_provider(monkeyp
|
|||
)
|
||||
|
||||
assert user_prov is not None
|
||||
assert calls == [("sk-test", "http://127.0.0.1:3000/api/v1")]
|
||||
assert calls == [("sk-test", "http://127.0.0.1:3000/api/v1", {"headers": None})]
|
||||
assert user_prov["models"] == ["old-configured-model", "new-live-model"]
|
||||
assert user_prov["total_models"] == 2
|
||||
|
||||
|
||||
def test_user_provider_live_model_probe_uses_extra_headers(monkeypatch):
    """providers.<name>.extra_headers must also apply to live /models probes."""
    # Replace the models.dev fetcher and the overlay table with empty stand-ins.
    monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
    monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})

    proxy_url = "http://localhost:8081/v1"
    proxy_headers = {
        "sleeve-harness": "hermes",
        "sleeve-base-url": "http://localhost:8081/v1",
    }
    recorded = []

    def fake_fetch_api_models(api_key, base_url, **kwargs):
        # Capture the keyword arguments too, so the ``headers`` value can be checked.
        recorded.append((api_key, base_url, kwargs))
        return ["live-model"]

    monkeypatch.setattr("hermes_cli.models.fetch_api_models", fake_fetch_api_models)

    providers = list_authenticated_providers(
        current_provider="llm-proxy",
        user_providers={
            "llm-proxy": {
                "name": "LLM Proxy",
                "base_url": proxy_url,
                "api_key": "local-key",
                # Pass a copy so the expectation below is an independent object.
                "extra_headers": dict(proxy_headers),
            }
        },
        custom_providers=[],
        max_models=50,
    )

    user_prov = None
    for candidate in providers:
        if candidate.get("is_user_defined") and candidate["slug"] == "llm-proxy":
            user_prov = candidate
            break

    assert user_prov is not None
    # Exactly one probe, carrying the provider's key, URL, and configured headers.
    assert recorded == [("local-key", proxy_url, {"headers": proxy_headers})]
    assert user_prov["models"] == ["live-model"]
|
||||
|
||||
|
||||
def test_list_authenticated_providers_dict_models_without_default_model(monkeypatch):
|
||||
"""Dict-format ``models:`` without a ``default_model`` must still expose
|
||||
every dict key, not collapse to an empty list."""
|
||||
|
|
@ -1063,10 +1114,11 @@ def test_section3_probes_no_key_endpoint_without_explicit_models(monkeypatch):
|
|||
|
||||
probed = {}
|
||||
|
||||
def _fake_fetch(api_key, api_url):
|
||||
def _fake_fetch(api_key, api_url, **kwargs):
|
||||
probed["called"] = True
|
||||
probed["api_key"] = api_key
|
||||
probed["api_url"] = api_url
|
||||
probed["kwargs"] = kwargs
|
||||
return ["live-model-1", "live-model-2", "live-model-3"]
|
||||
|
||||
monkeypatch.setattr("hermes_cli.models.fetch_api_models", _fake_fetch)
|
||||
|
|
@ -1088,6 +1140,7 @@ def test_section3_probes_no_key_endpoint_without_explicit_models(monkeypatch):
|
|||
|
||||
assert probed.get("called") is True, "no-key bare endpoint should be probed"
|
||||
assert probed["api_key"] == ""
|
||||
assert probed["kwargs"] == {"headers": None}
|
||||
row = next(p for p in providers if p["slug"] == "local-llamacpp")
|
||||
assert row["models"] == ["live-model-1", "live-model-2", "live-model-3"]
|
||||
assert row["total_models"] == 3
|
||||
|
|
@ -1099,7 +1152,7 @@ def test_section3_skips_probe_when_no_key_but_explicit_models(monkeypatch):
|
|||
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
|
||||
monkeypatch.setattr("hermes_cli.providers.HERMES_OVERLAYS", {})
|
||||
|
||||
def _fail_fetch(api_key, api_url):
|
||||
def _fail_fetch(api_key, api_url, **kwargs):
|
||||
raise AssertionError("should not probe when explicit models are set")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.models.fetch_api_models", _fail_fetch)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue