feat(gateway): persist per-session /model overrides across gateway restarts
Per-session /model overrides (_session_model_overrides) were in-memory only, so a gateway restart silently reverted every session to the global default model. Persist the non-secret parts (model/provider/base_url ONLY — never api_key) into the session entry in sessions.json and lazily rehydrate them on first use after a restart, re-resolving credentials through the normal runtime provider resolution.

- gateway/session.py: SessionEntry.model_override field with sanitize_model_override() (allowlist: model/provider/base_url) applied on both serialization and deserialization; SessionStore.set_model_override / get_model_override accessors. reset_session() already creates a fresh entry, so /new keeps its clear-on-reset semantics — a restart cannot resurrect an override the user reset away.
- gateway/slash_commands.py: write-through at both /model set sites (text command + picker) after storing the in-memory override.
- gateway/run.py: _rehydrate_session_model_override() called from _resolve_session_agent_runtime(); in-memory state always wins, credentials are re-resolved per provider (credential-less fallback on failure). Session expiry finalization also drops the persisted override.
- tests/gateway/test_session_model_override_persistence.py: restart round-trip, /new clearing, api_key-never-serialized (including tampered sessions.json), rehydration + live-state precedence + credential-failure degradation.

Salvaged from #3659 by @Git-on-my-level, narrowed to the restart-persistence gap confirmed in triage.
This commit is contained in:
parent
b98baa3039
commit
30e947e0a0
5 changed files with 396 additions and 0 deletions
|
|
@ -3632,6 +3632,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
resolved_session_key = None
|
||||
|
||||
model = _resolve_gateway_model(user_config)
|
||||
if resolved_session_key:
|
||||
self._rehydrate_session_model_override(resolved_session_key)
|
||||
override = self._session_model_overrides.get(resolved_session_key) if resolved_session_key else None
|
||||
if override:
|
||||
override_model = override.get("model", model)
|
||||
|
|
@ -7448,6 +7450,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
_update_prompt_pending.pop(key, None)
|
||||
with self.session_store._lock:
|
||||
entry.expiry_finalized = True
|
||||
# Session finalization is a conversation boundary —
|
||||
# drop the persisted /model override too so a later
|
||||
# message doesn't rehydrate it after the in-memory
|
||||
# override was popped above.
|
||||
entry.model_override = None
|
||||
self.session_store._save()
|
||||
logger.debug(
|
||||
"Session expiry finalized for %s",
|
||||
|
|
@ -15157,6 +15164,63 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
)
|
||||
return hashlib.sha256(blob.encode()).hexdigest()[:16]
|
||||
|
||||
def _rehydrate_session_model_override(self, session_key: str) -> None:
    """Restore a persisted /model override into the in-memory map on demand.

    ``_session_model_overrides`` lives only in memory, so without this step
    a gateway restart would drop every session back to the global default
    model. The session store keeps the non-secret parts of the override
    (model/provider/base_url); this method reads them back the first time
    the session is used and asks the normal runtime provider resolution for
    fresh credentials — api_key is never read from disk.

    Nothing happens when the session already has an in-memory override
    (live state wins), when the runner has no ``session_store``, when the
    store read fails, or when the store holds no override for the session.

    Args:
        session_key: Key of the session whose override should be restored.
    """
    # Live gateway state always takes precedence over what is on disk.
    if session_key in self._session_model_overrides:
        return

    session_store = getattr(self, "session_store", None)
    if session_store is None:
        return

    try:
        stored = session_store.get_model_override(session_key)
    except Exception:
        # Best effort: a failed read simply leaves the session on defaults.
        logger.debug(
            "Failed to read persisted session model override", exc_info=True
        )
        return
    if not stored:
        return

    provider = stored.get("provider")
    restored: Dict[str, Any] = {
        "model": stored.get("model"),
        "provider": provider,
        "base_url": stored.get("base_url"),
    }

    if provider:
        # Ask provider resolution for fresh credentials. If that fails
        # (e.g. credentials were removed since the switch) the override is
        # kept without credentials — _resolve_session_agent_runtime falls
        # back to env-based resolution and applies model/provider on top.
        try:
            resolved = _resolve_runtime_agent_kwargs_for_provider(provider)
            restored["api_key"] = resolved.get("api_key")
            restored["api_mode"] = resolved.get("api_mode")
            # A persisted base_url wins; only fill the gap when none was stored.
            if not restored.get("base_url"):
                restored["base_url"] = resolved.get("base_url")
        except Exception:
            logger.debug(
                "Credential re-resolution failed for persisted override "
                "(provider=%s); using credential-less override",
                provider, exc_info=True,
            )

    self._session_model_overrides[session_key] = restored
    logger.info(
        "Rehydrated persisted /model override for session=%s: model=%s provider=%s",
        session_key, restored.get("model"), provider or "",
    )
|
||||
|
||||
def _apply_session_model_override(
|
||||
self, session_key: str, model: str, runtime_kwargs: dict
|
||||
) -> tuple:
|
||||
|
|
|
|||
|
|
@ -574,6 +574,31 @@ def build_session_context_prompt(
|
|||
return "\n".join(lines)
|
||||
|
||||
|
||||
# Allowlist of /model override keys that may be written to disk.
# Everything else is dropped on purpose: ``api_key`` because credentials must
# NEVER reach sessions.json, and ``api_mode`` because it is re-derived from
# provider resolution. After a gateway restart the runner re-resolves
# credentials via the normal runtime provider resolution.
PERSISTABLE_MODEL_OVERRIDE_KEYS = ("model", "provider", "base_url")


def sanitize_model_override(override: Optional[Dict[str, Any]]) -> Optional[Dict[str, str]]:
    """Reduce *override* to its persistable, non-secret entries.

    Only keys listed in ``PERSISTABLE_MODEL_OVERRIDE_KEYS`` are kept, entries
    whose value is ``None`` or the empty string are dropped, and surviving
    values are converted with ``str()``. The input is never mutated.

    Args:
        override: Raw override mapping, possibly holding credentials.

    Returns:
        A new dict with the surviving entries, or ``None`` when *override*
        is not a dict or nothing persistable remains — so the result can be
        assigned directly to ``SessionEntry.model_override``.
    """
    if not isinstance(override, dict):
        return None

    persistable: Dict[str, str] = {}
    for key, value in override.items():
        if key not in PERSISTABLE_MODEL_OVERRIDE_KEYS:
            continue
        if value in (None, ""):
            continue
        persistable[key] = str(value)

    if not persistable:
        return None
    return persistable
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionEntry:
|
||||
"""
|
||||
|
|
@ -644,6 +669,15 @@ class SessionEntry:
|
|||
resume_reason: Optional[str] = None # e.g. "restart_timeout"
|
||||
last_resume_marked_at: Optional[datetime] = None
|
||||
|
||||
# Session-scoped /model override (model/provider/base_url ONLY — never
|
||||
# credentials). ``_session_model_overrides`` in the gateway runner is
|
||||
# in-memory, so before this field a gateway restart silently reverted
|
||||
# every session to the global default model. api_key/api_mode are
|
||||
# re-resolved through the normal runtime provider resolution when the
|
||||
# override is rehydrated after a restart and are never written to disk
|
||||
# (see sanitize_model_override / SessionStore.set_model_override).
|
||||
model_override: Optional[Dict[str, str]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"session_key": self.session_key,
|
||||
|
|
@ -675,6 +709,10 @@ class SessionEntry:
|
|||
"auto_reset_reason": self.auto_reset_reason,
|
||||
"reset_had_activity": self.reset_had_activity,
|
||||
}
|
||||
if self.model_override:
|
||||
# Defence-in-depth: strip credentials even if a caller stored an
|
||||
# unsanitized dict directly on the entry.
|
||||
result["model_override"] = sanitize_model_override(self.model_override)
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
return result
|
||||
|
|
@ -736,6 +774,7 @@ class SessionEntry:
|
|||
was_auto_reset=data.get("was_auto_reset", False),
|
||||
auto_reset_reason=data.get("auto_reset_reason"),
|
||||
reset_had_activity=data.get("reset_had_activity", False),
|
||||
model_override=sanitize_model_override(data.get("model_override")),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1515,6 +1554,37 @@ class SessionStore:
|
|||
entry.origin,
|
||||
)
|
||||
|
||||
def set_model_override(
    self, session_key: str, override: Optional[Dict[str, Any]]
) -> None:
    """Store, replace, or clear the /model override persisted for a session.

    The override is passed through ``sanitize_model_override`` first, so
    only the non-secret keys (model/provider/base_url) are kept;
    ``api_key``/``api_mode`` are re-resolved at rehydration time via the
    normal runtime provider resolution. The store is saved only when the
    sanitized value differs from what the entry already holds.

    Args:
        session_key: Session to update. Keys with no entry are ignored.
        override: New override, or ``None`` (or a dict with no persistable
            values) to clear the persisted override, e.g. on /new.
    """
    with self._lock:
        self._ensure_loaded_locked()
        entry = self._entries.get(session_key)
        if entry is not None:
            sanitized = sanitize_model_override(override)
            # Skip the disk write when nothing actually changed.
            if entry.model_override != sanitized:
                entry.model_override = sanitized
                self._save()
|
||||
|
||||
def get_model_override(self, session_key: str) -> Optional[Dict[str, str]]:
    """Look up the /model override persisted for *session_key*.

    Returns:
        A shallow copy of the stored override, or ``None`` when the session
        has no entry or its entry holds no override.
    """
    with self._lock:
        self._ensure_loaded_locked()
        entry = self._entries.get(session_key)
        if entry is None:
            return None
        if not entry.model_override:
            return None
        # Hand back a copy so callers cannot mutate the stored entry.
        return dict(entry.model_override)
|
||||
|
||||
def suspend_session(self, session_key: str) -> bool:
|
||||
"""Mark a session as suspended so it auto-resets on next access.
|
||||
|
||||
|
|
|
|||
|
|
@ -1597,6 +1597,20 @@ class GatewaySlashCommandsMixin:
|
|||
"api_mode": result.api_mode,
|
||||
}
|
||||
|
||||
# Write-through the non-secret parts to the session
|
||||
# store so the picked model survives a gateway restart
|
||||
# (api_key is never persisted).
|
||||
try:
|
||||
_self.session_store.set_model_override(
|
||||
_session_key,
|
||||
_self._session_model_overrides[_session_key],
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Failed to persist session model override",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Evict cached agent so the next turn creates a fresh
|
||||
# agent from the override rather than relying on the
|
||||
# stale cache signature to trigger a rebuild.
|
||||
|
|
@ -1831,6 +1845,19 @@ class GatewaySlashCommandsMixin:
|
|||
"api_mode": result.api_mode,
|
||||
}
|
||||
|
||||
# Write-through the non-secret parts (model/provider/base_url) to
|
||||
# the session store so the override survives a gateway restart.
|
||||
# api_key/api_mode are never persisted — they are re-resolved via
|
||||
# runtime provider resolution on rehydration.
|
||||
try:
|
||||
self.session_store.set_model_override(
|
||||
session_key, self._session_model_overrides[session_key]
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Failed to persist session model override", exc_info=True
|
||||
)
|
||||
|
||||
# Evict cached agent so the next turn creates a fresh agent from the
|
||||
# override rather than relying on cache signature mismatch detection.
|
||||
self._evict_cached_agent(session_key)
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ AUTHOR_MAP = {
|
|||
"r0gersm1th@users.noreply.github.com": "r0gersm1th", # PR #3219 salvage (whatsapp bridge: resolve LID sender IDs to phone numbers in the message payload so phone-based allowlists match; commit authored by collaborator r0gersm1th, PR by @ajmeese7)
|
||||
"louis@letsfive.io": "Mibayy", # PR #3296 salvage (status: provider label honors config.yaml model.base_url, not just OPENAI_BASE_URL env)
|
||||
"me@keslerm.com": "keslerm", # PR #3459 salvage (gateway: 'log' tool_progress mode — silent in chat, tool calls appended to ~/.hermes/logs/tool_calls.log via rotating handler; duplicate of #3458 by @dlkakbs who submitted 4 min earlier — both credited)
|
||||
"david.d.zhang@gmail.com": "Git-on-my-level", # PR #3659 salvage (gateway: persist per-session /model overrides across gateway restarts)
|
||||
"tarunravi@gmail.com": "tarunravi", # PR #2696 salvage (api-server: inline MEDIA:<path> image tags as base64 data URLs in final responses so remote OpenAI-compatible frontends can render server-local screenshots; the PR's tool-progress-streaming and SSE-sentinel pieces were independently superseded on main)
|
||||
"aqdrgg19@gmail.com": "VolodymyrBg", # PR #2861 salvage (webhook: drop the unused full request payload from retained _delivery_info entries — up to ~1MB dead weight per delivery for the 1h idempotency TTL)
|
||||
"ohyes9711@gmail.com": "CharmingGroot", # PR #2794 salvage (email: guard msg_data[0][1] against malformed IMAP fetch structures so one bad response can't abort the batch and permanently lose seen-marked messages; Message-ID domain falls back to localhost when EMAIL_ADDRESS lacks '@')
|
||||
|
|
|
|||
234
tests/gateway/test_session_model_override_persistence.py
Normal file
234
tests/gateway/test_session_model_override_persistence.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
"""Per-session /model overrides must survive gateway restarts (#3659 salvage).
|
||||
|
||||
``GatewayRunner._session_model_overrides`` is in-memory, so before persistence
|
||||
a gateway restart silently reverted every session to the global default model.
|
||||
The non-secret parts (model/provider/base_url) are now written through to the
|
||||
session store (``SessionEntry.model_override`` in sessions.json) and lazily
|
||||
rehydrated on first use after a restart, with credentials re-resolved through
|
||||
the normal runtime provider resolution.
|
||||
|
||||
Covers:
|
||||
- the override survives a simulated restart (a second SessionStore instance
|
||||
reading the same sessions dir, and a fresh runner rehydrating from it)
|
||||
- /new (SessionStore.reset_session) clears the persisted override so a
|
||||
restart cannot resurrect it
|
||||
- api_key is NEVER serialized to sessions.json
|
||||
"""
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform
|
||||
from gateway.session import (
|
||||
SessionEntry,
|
||||
SessionSource,
|
||||
SessionStore,
|
||||
sanitize_model_override,
|
||||
)
|
||||
|
||||
# Full in-memory override as the tests hand it to the store, including the
# api_key and api_mode entries that must never reach sessions.json.
OVERRIDE = dict(
    model="gpt-5o",
    provider="openai",
    api_key="sk-SUPER-SECRET-do-not-persist",
    base_url="https://api.openai.example/v1",
    api_mode="responses",
)
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
    """Return the Telegram DM ``SessionSource`` used by the tests in this module."""
    source_fields = {
        "platform": Platform.TELEGRAM,
        "user_id": "u1",
        "chat_id": "c1",
        "user_name": "tester",
        "chat_type": "dm",
    }
    return SessionSource(**source_fields)
|
||||
|
||||
|
||||
@pytest.fixture
def store_factory(tmp_path, monkeypatch):
    """Build SessionStores over a shared sessions dir, without SQLite.

    ``hermes_state.SessionDB`` is replaced with a stub that always raises,
    and every store returned by the factory is asserted to have ``_db`` set
    to ``None``. All stores share ``tmp_path``, so creating a second one
    simulates a gateway restart reading the same sessions directory.

    Returns:
        A zero-argument callable producing a fresh ``SessionStore``.
    """

    def _raise(*args, **kwargs):
        # Accept any call signature so the stub always fails with the
        # intended RuntimeError, never an incidental TypeError.
        raise RuntimeError("SQLite disabled in test")

    import hermes_state

    monkeypatch.setattr(hermes_state, "SessionDB", _raise)

    def _make() -> SessionStore:
        store = SessionStore(sessions_dir=tmp_path, config=GatewayConfig())
        assert store._db is None
        return store

    return _make
|
||||
|
||||
|
||||
def _sessions_json(tmp_path) -> str:
|
||||
return (tmp_path / "sessions.json").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_override_persists_and_survives_restart(store_factory, tmp_path):
    """An override written by one store is readable from a fresh store."""
    first_store = store_factory()
    session_key = first_store.get_or_create_session(_make_source()).session_key

    first_store.set_model_override(session_key, OVERRIDE)

    # A second store over the same directory stands in for a restarted gateway.
    restarted_store = store_factory()
    expected = {
        "model": "gpt-5o",
        "provider": "openai",
        "base_url": "https://api.openai.example/v1",
    }
    assert restarted_store.get_model_override(session_key) == expected
|
||||
|
||||
|
||||
def test_api_key_never_serialized(store_factory, tmp_path):
    """Neither the secret value nor the ``api_key`` name reaches sessions.json."""
    store = store_factory()
    entry = store.get_or_create_session(_make_source())

    store.set_model_override(entry.session_key, OVERRIDE)

    on_disk = _sessions_json(tmp_path)
    assert "sk-SUPER-SECRET-do-not-persist" not in on_disk
    assert "api_key" not in on_disk

    # api_mode is re-derived from provider resolution, so it is not stored either.
    persisted_override = json.loads(on_disk)[entry.session_key]["model_override"]
    assert set(persisted_override) == {"model", "provider", "base_url"}
|
||||
|
||||
|
||||
def test_from_dict_strips_api_key_from_tampered_json():
    """Even a hand-edited sessions.json with an api_key must not load one."""
    tampered_override = {
        "model": "m1",
        "provider": "p1",
        "api_key": "sk-injected",
        "api_mode": "chat_completions",
    }
    raw_entry = {
        "session_key": "k1",
        "session_id": "s1",
        "created_at": "2026-01-01T00:00:00",
        "updated_at": "2026-01-01T00:00:00",
        "model_override": tampered_override,
    }

    loaded = SessionEntry.from_dict(raw_entry)

    assert loaded.model_override == {"model": "m1", "provider": "p1"}
|
||||
|
||||
|
||||
def test_new_clears_persisted_override(store_factory, tmp_path):
    """/new resets the session; the persisted override must not survive it."""
    store = store_factory()
    session_key = store.get_or_create_session(_make_source()).session_key

    store.set_model_override(session_key, OVERRIDE)
    assert store.get_model_override(session_key) is not None

    # The /new path goes through SessionStore.reset_session, which creates
    # a fresh entry for the same key.
    fresh_entry = store.reset_session(session_key)
    assert fresh_entry is not None
    assert store.get_model_override(session_key) is None

    # A restart after /new must NOT bring the override back.
    restarted_store = store_factory()
    assert restarted_store.get_model_override(session_key) is None
    assert "gpt-5o" not in _sessions_json(tmp_path)
|
||||
|
||||
|
||||
def _make_runner(store):
    """Return a bare ``GatewayRunner`` wired to *store* with no overrides.

    The instance is allocated with ``object.__new__`` so ``__init__`` does
    not run; only the two attributes set below are populated.
    """
    from gateway.run import GatewayRunner

    bare_runner = object.__new__(GatewayRunner)
    bare_runner._session_model_overrides = {}
    bare_runner.session_store = store
    return bare_runner
|
||||
|
||||
|
||||
def test_runner_rehydrates_override_after_restart(store_factory):
    """A fresh runner restores the stored override plus live credentials."""
    original_store = store_factory()
    session_key = original_store.get_or_create_session(_make_source()).session_key
    original_store.set_model_override(session_key, OVERRIDE)

    # Simulated restart: a new store and a new runner whose in-memory
    # override map is empty; credentials come from provider resolution.
    runner = _make_runner(store_factory())
    freshly_resolved = {
        "api_key": "sk-fresh-from-keychain",
        "api_mode": "responses",
        "base_url": "https://api.openai.example/v1",
        "provider": "openai",
    }
    with patch(
        "gateway.run._resolve_runtime_agent_kwargs_for_provider",
        return_value=freshly_resolved,
    ):
        runner._rehydrate_session_model_override(session_key)

    rehydrated = runner._session_model_overrides[session_key]
    assert rehydrated["model"] == "gpt-5o"
    assert rehydrated["provider"] == "openai"
    assert rehydrated["base_url"] == "https://api.openai.example/v1"
    # Credentials come from live resolution, never from disk.
    assert rehydrated["api_key"] == "sk-fresh-from-keychain"
    assert rehydrated["api_mode"] == "responses"
|
||||
|
||||
|
||||
def test_runner_rehydrate_keeps_live_override(store_factory):
    """An in-memory override (live gateway state) always wins over disk."""
    store = store_factory()
    session_key = store.get_or_create_session(_make_source()).session_key
    store.set_model_override(session_key, OVERRIDE)

    runner = _make_runner(store)
    in_memory = {"model": "live-model", "provider": "anthropic"}
    runner._session_model_overrides[session_key] = in_memory

    runner._rehydrate_session_model_override(session_key)

    # Identity check: the very same dict object must still be in place.
    assert runner._session_model_overrides[session_key] is in_memory
|
||||
|
||||
|
||||
def test_runner_rehydrate_noop_without_persisted_override(store_factory):
    """With nothing persisted, rehydration leaves the override map empty."""
    store = store_factory()
    session_key = store.get_or_create_session(_make_source()).session_key
    runner = _make_runner(store)

    runner._rehydrate_session_model_override(session_key)

    assert runner._session_model_overrides == {}
|
||||
|
||||
|
||||
def test_runner_rehydrate_survives_credential_resolution_failure(store_factory):
    """Missing credentials degrade to a credential-less override, not a crash."""
    store = store_factory()
    session_key = store.get_or_create_session(_make_source()).session_key
    store.set_model_override(session_key, OVERRIDE)
    runner = _make_runner(store)

    failing_resolution = patch(
        "gateway.run._resolve_runtime_agent_kwargs_for_provider",
        side_effect=RuntimeError("no credentials"),
    )
    with failing_resolution:
        runner._rehydrate_session_model_override(session_key)

    degraded = runner._session_model_overrides[session_key]
    assert degraded["model"] == "gpt-5o"
    assert degraded.get("api_key") is None
|
||||
|
||||
|
||||
def test_sanitize_model_override():
    """Inputs with nothing persistable yield None; OVERRIDE keeps three keys."""
    nothing_to_persist = (
        None,
        {},
        {"api_key": "sk-x", "api_mode": "chat"},
    )
    for candidate in nothing_to_persist:
        assert sanitize_model_override(candidate) is None

    expected = {
        "model": "gpt-5o",
        "provider": "openai",
        "base_url": "https://api.openai.example/v1",
    }
    assert sanitize_model_override(OVERRIDE) == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue