fix: make Nous Portal access token resolution resilient

- Track auth store source path on Nous state reads and write rotated
  OAuth refresh tokens back to the same store, preventing stale-token
  replays when Hermes falls back to a global/root auth.json.
- Skip Nous fallback entries locally when no access/refresh token is
  present, suppressing repeated failed resolution attempts within a
  session.
- Sync session model metadata after fallback switches so the gateway
  DB reflects the backend that actually served the latest turn.
This commit is contained in:
HODLCLONE 2026-06-19 20:20:01 -04:00 committed by Teknium
parent cfbc7ed1f9
commit 6ed2f5d76f
4 changed files with 276 additions and 32 deletions

View file

@ -1124,6 +1124,35 @@ def rewrite_prompt_model_identity(agent, model: str, provider: str) -> None:
agent._cached_system_prompt = sp
def _fallback_entry_key(fb: dict) -> tuple[str, str, str]:
return (
str(fb.get("provider") or "").strip().lower(),
str(fb.get("model") or "").strip(),
str(fb.get("base_url") or "").strip().rstrip("/"),
)
def _fallback_entry_unavailable_without_network(agent, fb: dict) -> Optional[str]:
"""Return a skip reason for fallback entries known to be unusable locally."""
fb_provider = (fb.get("provider") or "").strip().lower()
if fb_provider != "nous":
return None
try:
from hermes_cli.auth import get_provider_auth_state
state = get_provider_auth_state("nous") or {}
except Exception as exc:
return f"nous_auth_unreadable:{type(exc).__name__}"
access_value = state.get("access_token")
refresh_value = state.get("refresh_token")
has_access = isinstance(access_value, str) and bool(access_value.strip())
has_refresh = isinstance(refresh_value, str) and bool(refresh_value.strip())
if not (has_access or has_refresh):
return "nous_token_missing"
return None
def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool:
"""Switch to the next fallback model/provider in the chain.
@ -1164,10 +1193,29 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool
return False
fb = agent._fallback_chain[agent._fallback_index]
agent._fallback_index += 1
fb_key = _fallback_entry_key(fb)
unavailable = getattr(agent, "_unavailable_fallback_keys", None)
if unavailable is None:
unavailable = set()
agent._unavailable_fallback_keys = unavailable
if fb_key in unavailable:
logger.debug("Fallback skip: %s previously marked unavailable", fb_key)
return agent._try_activate_fallback(reason)
fb_provider = (fb.get("provider") or "").strip().lower()
fb_model = (fb.get("model") or "").strip()
if not fb_provider or not fb_model:
return agent._try_activate_fallback() # skip invalid, try next
return agent._try_activate_fallback(reason) # skip invalid, try next
local_skip_reason = _fallback_entry_unavailable_without_network(agent, fb)
if local_skip_reason:
unavailable.add(fb_key)
logger.warning(
"Fallback skip: %s/%s is not locally usable (%s); suppressing for this session",
fb_provider,
fb_model,
local_skip_reason,
)
return agent._try_activate_fallback(reason)
# Skip entries that resolve to the current (provider, model) — falling
# back to the same backend that just failed loops the failure. Compare
@ -1182,7 +1230,7 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool
"Fallback skip: chain entry %s/%s matches current provider/model",
fb_provider, fb_model,
)
return agent._try_activate_fallback()
return agent._try_activate_fallback(reason)
if (
fb_base_url_for_dedup
and current_base_url
@ -1193,7 +1241,7 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool
"Fallback skip: chain entry base_url %s matches current backend",
fb_base_url_for_dedup,
)
return agent._try_activate_fallback()
return agent._try_activate_fallback(reason)
# Use centralized router for client construction.
# raw_codex=True because the main agent needs direct responses.stream()
@ -1224,7 +1272,8 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool
logger.warning(
"Fallback to %s failed: provider not configured",
fb_provider)
return agent._try_activate_fallback() # try next in chain
unavailable.add(fb_key)
return agent._try_activate_fallback(reason) # try next in chain
try:
from hermes_cli.model_normalize import normalize_model_for_provider
@ -1425,8 +1474,10 @@ def try_activate_fallback(agent, reason: "FailoverReason | None" = None) -> bool
)
return True
except Exception as e:
if fb_provider == "nous":
unavailable.add(fb_key)
logger.error("Failed to activate fallback %s: %s", fb_model, e)
return agent._try_activate_fallback() # try next in chain
return agent._try_activate_fallback(reason) # try next in chain

View file

@ -3690,6 +3690,62 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
route["request_overrides"] = overrides or {}
return route
def _sync_session_model_from_agent(self, session_id: str, agent: Any) -> None:
"""Persist the runtime model/provider actually used by a gateway turn.
Provider fallback can switch ``agent.model``/``agent.provider`` after the
session row was created. Keep the session DB metadata in sync so session
lists, desktop/dashboard details, and follow-up session tooling report the
backend that actually answered the latest turn.
"""
if not session_id or agent is None or self._session_db is None:
return
model = getattr(agent, "model", None)
if not model:
return
runtime = {
"provider": getattr(agent, "provider", None),
"base_url": getattr(agent, "base_url", None),
"api_mode": getattr(agent, "api_mode", None),
"fallback_active": bool(getattr(agent, "_fallback_activated", False)),
}
runtime = {k: v for k, v in runtime.items() if v not in (None, "")}
def _do(conn):
import json as _json
row = conn.execute(
"SELECT model, model_config FROM sessions WHERE id = ?",
(session_id,),
).fetchone()
if row is None:
return
try:
current_model = row["model"]
raw_config = row["model_config"]
except Exception:
current_model = row[0]
raw_config = row[1]
try:
config = _json.loads(raw_config) if raw_config else {}
except Exception:
config = {}
if not isinstance(config, dict):
config = {}
gateway_runtime = dict(config.get("gateway_runtime") or {})
if current_model == model and all(gateway_runtime.get(k) == v for k, v in runtime.items()):
return
config["gateway_runtime"] = runtime
conn.execute(
"UPDATE sessions SET model = ?, model_config = ? WHERE id = ?",
(model, _json.dumps(config), session_id),
)
try:
self._session_db._execute_write(_do) # noqa: SLF001 - SessionDB exposes no metadata updater
except Exception:
logger.debug("Failed to sync gateway session model metadata", exc_info=True)
async def _handle_adapter_fatal_error(self, adapter: BasePlatformAdapter) -> None:
"""React to an adapter failure after startup.
@ -17629,6 +17685,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
)
effective_session_id = agent_session_id
self._sync_session_model_from_agent(effective_session_id, agent)
# history_offset=0 whenever the agent's message list no longer has
# the original history prefix — i.e. on rotation (split) OR in-place
# compaction. In both cases the returned `messages` is the compacted

View file

@ -1157,6 +1157,36 @@ def _save_auth_store(auth_store: Dict[str, Any], target_path: Optional[Path] = N
return auth_file
def _load_provider_state_with_source(
    auth_store: Dict[str, Any],
    provider_id: str,
) -> tuple[Optional[Dict[str, Any]], Optional[Path]]:
    """Look up a provider's state and report which auth.json supplied it.

    The given ``auth_store`` is consulted first; a hit is paired with the
    active auth file path. Otherwise the global auth store is consulted and a
    hit is paired with the global auth file path. The returned state is a
    shallow copy. ``(None, None)`` means neither store has a dict entry for
    ``provider_id``.

    Knowing the source matters for refresh flows that rotate single-use OAuth
    refresh tokens: the rotated chain has to be written back to the store it
    was read from, otherwise that store keeps an already-consumed token that
    a later process would replay.
    """

    def _lookup(store: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        section = store.get("providers")
        if not isinstance(section, dict):
            return None
        entry = section.get(provider_id)
        return dict(entry) if isinstance(entry, dict) else None

    local_state = _lookup(auth_store)
    if local_state is not None:
        return local_state, _auth_file_path()

    fallback_path = _global_auth_file_path()
    fallback_store = _load_global_auth_store()
    if fallback_store:
        fallback_state = _lookup(fallback_store)
        if fallback_state is not None:
            return fallback_state, fallback_path

    return None, None
def _load_provider_state(auth_store: Dict[str, Any], provider_id: str) -> Optional[Dict[str, Any]]:
"""Return a provider's persisted state.
@ -1168,22 +1198,8 @@ def _load_provider_state(auth_store: Dict[str, Any], provider_id: str) -> Option
the profile, the profile state fully shadows the global state on the next
read. See issue #18594 follow-up.
"""
providers = auth_store.get("providers")
if isinstance(providers, dict):
state = providers.get(provider_id)
if isinstance(state, dict):
return dict(state)
# Read-only fallback to the global-root auth store (profile mode only;
# returns empty dict in classic mode so this is a no-op).
global_store = _load_global_auth_store()
if global_store:
global_providers = global_store.get("providers")
if isinstance(global_providers, dict):
global_state = global_providers.get(provider_id)
if isinstance(global_state, dict):
return dict(global_state)
return None
state, _source_path = _load_provider_state_with_source(auth_store, provider_id)
return state
def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Dict[str, Any]) -> None:
@ -1195,6 +1211,30 @@ def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Di
auth_store["active_provider"] = provider_id
def _save_provider_state_to_source(
    auth_store: Dict[str, Any],
    provider_id: str,
    state: Dict[str, Any],
    source_path: Optional[Path],
) -> None:
    """Write ``state`` for ``provider_id`` into the auth store it was loaded from.

    ``source_path`` of ``None`` means the active auth file. When the source is
    the active file, ``auth_store`` itself is updated and saved. Otherwise the
    store at ``source_path`` is loaded, updated, and saved back to that path;
    ``auth_store`` is left untouched in that case.
    """
    active_path = _auth_file_path()
    target_path = active_path if source_path is None else source_path

    # Compare resolved paths so differently spelled paths to the same file
    # still match; fall back to plain equality if resolution fails.
    try:
        targets_active = target_path.resolve(strict=False) == active_path.resolve(strict=False)
    except Exception:
        targets_active = target_path == active_path

    if not targets_active:
        other_store = _load_auth_store(target_path)
        _save_provider_state(other_store, provider_id, state)
        _save_auth_store(other_store, target_path=target_path)
        return

    _save_provider_state(auth_store, provider_id, state)
    _save_auth_store(auth_store)
def _store_provider_state(
auth_store: Dict[str, Any],
provider_id: str,
@ -5337,7 +5377,7 @@ def resolve_nous_access_token(
"""Resolve a refresh-aware Nous Portal access token for managed tool gateways."""
with _auth_store_lock():
auth_store = _load_auth_store()
state = _load_provider_state(auth_store, "nous")
state, state_source_path = _load_provider_state_with_source(auth_store, "nous")
if not state:
raise AuthError(
@ -5377,8 +5417,7 @@ def resolve_nous_access_token(
if not _is_expiring(state.get("expires_at"), refresh_skew_seconds):
if merged_shared:
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
_save_provider_state_to_source(auth_store, "nous", state, state_source_path)
return access_token
if not isinstance(refresh_token, str) or not refresh_token:
@ -5413,8 +5452,7 @@ def resolve_nous_access_token(
exc,
reason="managed_access_token_refresh_failure",
)
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
_save_provider_state_to_source(auth_store, "nous", state, state_source_path)
raise
now = datetime.now(timezone.utc)
@ -5435,8 +5473,7 @@ def resolve_nous_access_token(
"insecure": verify is False,
"ca_bundle": verify if isinstance(verify, str) else None,
}
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
_save_provider_state_to_source(auth_store, "nous", state, state_source_path)
_write_shared_nous_state(state)
return state["access_token"]
@ -5662,7 +5699,7 @@ def resolve_nous_runtime_credentials(
with _auth_store_lock():
auth_store = _load_auth_store()
state = _load_provider_state(auth_store, "nous")
state, state_source_path = _load_provider_state_with_source(auth_store, "nous")
if not state:
raise AuthError("Hermes is not logged into Nous Portal.",
@ -5724,8 +5761,7 @@ def resolve_nous_runtime_credentials(
)
return
try:
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
_save_provider_state_to_source(auth_store, "nous", state, state_source_path)
except Exception as exc:
_oauth_trace(
"nous_state_persist_failed",
@ -5904,7 +5940,7 @@ def resolve_nous_runtime_credentials(
"expires_at": expires_at,
"expires_in": expires_in,
"source": NOUS_AUTH_PATH_INVOKE_JWT,
"auth_path": NOUS_AUTH_PATH_INVOKE_JWT,
"auth_path": str(state_source_path or _auth_file_path()),
}

View file

@ -0,0 +1,100 @@
"""Tests for Nous fallback local-availability suppression.
When Nous token material is missing locally, the fallback chain should
not repeatedly attempt Nous resolution; it must skip the Nous entry and
continue to the next provider.
"""
from __future__ import annotations
from unittest.mock import patch
from run_agent import AIAgent
def _make_agent(fallback_model=None):
    """Construct a quiet ``AIAgent`` for fallback tests.

    ``run_agent.get_tool_definitions``, ``run_agent.check_toolset_requirements``
    and ``run_agent.OpenAI`` are patched during construction, and the agent's
    ``client`` is reset to ``None`` before it is returned.
    """
    agent_kwargs = dict(
        api_key="test-key",
        base_url="https://openrouter.ai/api/v1",
        quiet_mode=True,
        skip_context_files=True,
        skip_memory=True,
        fallback_model=fallback_model,
    )
    with patch("run_agent.get_tool_definitions", return_value=[]):
        with patch("run_agent.check_toolset_requirements", return_value={}):
            with patch("run_agent.OpenAI"):
                agent = AIAgent(**agent_kwargs)
                agent.client = None
                return agent
def _mock_client(base_url="https://openrouter.ai/api/v1", api_key="fb-key"):
mock = type("Client", (), {})()
mock.base_url = base_url
mock.api_key = api_key
mock.chat = type("Chat", (), {})()
mock.chat.completions = type("Completions", (), {})()
mock.chat.completions.create = lambda *args, **kwargs: None
return mock
class TestNousFallbackLocalAvailability:
    """Fallback-chain behaviour when local Nous credentials are absent or present."""

    @staticmethod
    def _nous_then_openai():
        """Return a fresh two-entry chain: Nous first, then OpenAI."""
        return [
            {"provider": "nous", "model": "anthropic/claude-sonnet-4.6"},
            {"provider": "openai", "model": "gpt-4o"},
        ]

    def test_missing_nous_token_is_skipped_once(self):
        """Nous fallback is skipped when no access/refresh token is stored."""
        agent = _make_agent(fallback_model=self._nous_then_openai())

        with patch("hermes_cli.auth.get_provider_auth_state", return_value={}):
            with patch(
                "agent.auxiliary_client.resolve_provider_client",
                return_value=(_mock_client(api_key="fb"), "gpt-4o"),
            ):
                activated = agent._try_activate_fallback(None)

        assert activated is True
        assert agent.model == "gpt-4o"

    def test_nous_unavailable_not_retried_in_same_session(self):
        """After Nous is skipped once, subsequent activations continue further."""
        agent = _make_agent(fallback_model=self._nous_then_openai())

        with patch("hermes_cli.auth.get_provider_auth_state", return_value={}):
            agent._try_activate_fallback(None)

        expected_key = ("nous", "anthropic/claude-sonnet-4.6", "")
        recorded = getattr(agent, "_unavailable_fallback_keys", set())
        assert expected_key in recorded

    def test_present_nous_token_allows_activation(self):
        """Nous is considered when token material exists."""
        agent = _make_agent(fallback_model=self._nous_then_openai())
        stored_tokens = {"access_token": "abc", "refresh_token": "xyz"}

        with patch("hermes_cli.auth.get_provider_auth_state", return_value=stored_tokens):
            with patch(
                "agent.auxiliary_client.resolve_provider_client",
                return_value=(_mock_client(api_key="fb"), "anthropic/claude-sonnet-4.6"),
            ):
                activated = agent._try_activate_fallback(None)

        assert activated is True
        assert agent.provider == "nous"