feat(gateway): per-channel model and system prompt overrides (Fixes #1955)

- ChannelOverride + channel_overrides on PlatformConfig
- Resolve model/runtime: session /model, then channel_overrides, then global
- Thread/parent channel lookup; bridge discord.channel_overrides from YAML
- Drop unrelated test and delegate_tool changes from PR scope
This commit is contained in:
crazywriter1 2026-05-17 16:31:02 +03:00 committed by Teknium
parent ebef73f6b8
commit 0010c14e66
4 changed files with 336 additions and 40 deletions

View file

@ -379,7 +379,7 @@ class PlatformConfig:
# noise; keep True for back-channels where the operator wants them.
gateway_restart_notification: bool = True
# Whether the gateway shows a "typing…" / "is thinking…" status indicator
# Whether the gateway shows a "typing…" / "is thinking…" status indicator
# while the agent processes a message on this platform. Default True
# preserves prior behavior. Set False on platforms where the indicator is
# unwanted (e.g. Slack's assistant.threads.setStatus "is thinking…", which
@ -420,7 +420,7 @@ class PlatformConfig:
if "home_channel" in data:
home_channel = HomeChannel.from_dict(data["home_channel"])
# gateway_restart_notification may be bridged into extra via the
# gateway_restart_notification may be bridged into extra via the
# shared-key loop in load_gateway_config(); check both top-level
# and extra so YAML ``discord: gateway_restart_notification: false``
# works without needing a separate platforms: block.
@ -448,7 +448,7 @@ class PlatformConfig:
api_key=data.get("api_key"),
home_channel=home_channel,
reply_to_mode=data.get("reply_to_mode", "first"),
gateway_restart_notification=_coerce_bool(_grn, True),
gateway_restart_notification=_coerce_bool(_grn, True),
typing_indicator=_coerce_bool(_typing, True),
channel_overrides=channel_overrides,
extra=data.get("extra", {}),
@ -1103,8 +1103,20 @@ def load_gateway_config() -> GatewayConfig:
bridged["gateway_restart_notification"] = platform_cfg["gateway_restart_notification"]
if "typing_indicator" in platform_cfg:
bridged["typing_indicator"] = platform_cfg["typing_indicator"]
has_channel_overrides = "channel_overrides" in platform_cfg
if has_channel_overrides:
raw_overrides = platform_cfg.get("channel_overrides")
if isinstance(raw_overrides, dict):
plat_data, _extra = _ensure_platform_extra_dict(
platforms_data, plat.value
)
plat_data["channel_overrides"] = {
str(cid): ov_data
for cid, ov_data in raw_overrides.items()
if isinstance(ov_data, dict)
}
enabled_was_explicit = _cfg_toplevel and "enabled" in platform_cfg
if not bridged and not enabled_was_explicit:
if not bridged and not enabled_was_explicit and not has_channel_overrides:
continue
plat_data, extra = _ensure_platform_extra_dict(platforms_data, plat.value)
if enabled_was_explicit:

View file

@ -2306,18 +2306,54 @@ def _resolve_gateway_model(config: dict | None = None) -> str:
return ""
def _channel_override_lookup_keys(
chat_id: str,
*,
thread_id: Optional[str] = None,
parent_id: Optional[str] = None,
) -> list[str]:
"""Ordered, de-duplicated keys for ``channel_overrides`` lookup.
Matches ``resolve_channel_prompt`` semantics: exact thread/channel id first,
then parent channel/forum id (Discord threads inherit parent overrides).
"""
keys: list[str] = []
seen: set[str] = set()
for key in (chat_id, thread_id, parent_id):
if not key:
continue
sk = str(key)
if sk in seen:
continue
seen.add(sk)
keys.append(sk)
return keys
def _get_channel_override(
config: GatewayConfig,
platform: Platform,
chat_id: str,
*,
thread_id: Optional[str] = None,
parent_id: Optional[str] = None,
) -> Optional[ChannelOverride]:
"""Return per-channel override for this platform/chat_id, or None."""
if not chat_id:
return None
"""Return per-channel override for this platform/chat_id, or None.
Looks up ``channel_overrides`` by ``chat_id``, then ``thread_id``, then
``parent_id`` (forum threads / child channels inherit the parent entry).
"""
platform_config = config.platforms.get(platform)
if not platform_config or not platform_config.channel_overrides:
return None
return platform_config.channel_overrides.get(str(chat_id))
overrides = platform_config.channel_overrides
for key in _channel_override_lookup_keys(
chat_id, thread_id=thread_id, parent_id=parent_id
):
ov = overrides.get(key)
if ov is not None:
return ov
return None
def _resolve_hermes_bin() -> Optional[list[str]]:
@ -3579,11 +3615,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
session_key: Optional[str] = None,
user_config: Optional[dict] = None,
) -> tuple[str, dict]:
"""Resolve model/runtime for a session, honoring session-scoped /model overrides.
"""Resolve model/runtime for a session.
If the session override already contains a complete provider bundle
(provider/api_key/base_url/api_mode), prefer it directly instead of
resolving fresh global runtime state first.
Priority (highest first): session ``/model`` → ``channel_overrides`` →
global config/env (``_resolve_gateway_model(user_config)`` and default
provider resolution).
"""
resolved_session_key = session_key
if not resolved_session_key and source is not None:
@ -3632,30 +3668,43 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
runtime_model,
)
model = runtime_model
cfg = getattr(self, "config", None)
if cfg and source is not None:
chat_id = str(source.chat_id) if source.chat_id else ""
thread_id = (
str(source.thread_id) if getattr(source, "thread_id", None) else None
)
parent_id = (
str(source.parent_chat_id)
if getattr(source, "parent_chat_id", None)
else None
)
ch = _get_channel_override(
cfg,
source.platform,
chat_id,
thread_id=thread_id,
parent_id=parent_id,
)
if ch:
if ch.model:
model = ch.model
if ch.provider:
runtime_kwargs = _resolve_runtime_agent_kwargs_for_provider(
ch.provider
)
ch_runtime_model = runtime_kwargs.pop("model", None)
# Only adopt the provider's bundled model when the override
# did not specify an explicit model.
if ch_runtime_model and not ch.model:
model = ch_runtime_model
if override and resolved_session_key:
model, runtime_kwargs = self._apply_session_model_override(
resolved_session_key, model, runtime_kwargs
)
cfg = getattr(self, "config", None)
if cfg and source is not None and source.chat_id:
ch = _get_channel_override(cfg, source.platform, str(source.chat_id))
if ch:
channel_touch = False
if ch.model and not (override and override.get("model")):
model = ch.model
channel_touch = True
if ch.provider and not (override and override.get("provider")):
runtime_kwargs = _resolve_runtime_agent_kwargs_for_provider(ch.provider)
runtime_model = runtime_kwargs.pop("model", None)
if runtime_model:
model = runtime_model or model
channel_touch = True
if channel_touch and override and resolved_session_key:
model, runtime_kwargs = self._apply_session_model_override(
resolved_session_key, model, runtime_kwargs
)
# When the config has no model.default but a provider was resolved
# (e.g. user ran `hermes auth add openai-codex` without `hermes model`),
# fall back to the provider's first catalog model so the API call
@ -4528,20 +4577,53 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
cfg = _load_gateway_runtime_config()
return str(cfg_get(cfg, "agent", "system_prompt", default="") or "").strip()
def _resolve_model_for_channel(self, platform: Platform, chat_id: str) -> str:
"""Resolve model for this channel: channel_overrides[channel_id] else global default."""
def _resolve_model_for_channel(
self,
platform: Platform,
chat_id: str,
*,
user_config: Optional[dict] = None,
thread_id: Optional[str] = None,
parent_id: Optional[str] = None,
) -> str:
"""Resolve model for this channel: channel_overrides else global default."""
config = getattr(self, "config", None)
if config:
override = _get_channel_override(config, platform, chat_id)
override = _get_channel_override(
config,
platform,
chat_id,
thread_id=thread_id,
parent_id=parent_id,
)
if override and override.model:
return override.model
return _resolve_gateway_model()
return _resolve_gateway_model(user_config)
def _get_system_prompt_for_channel(self, platform: Platform, chat_id: str) -> str:
"""System prompt for this channel: channel override else global ephemeral."""
def _get_system_prompt_for_channel(
self,
platform: Platform,
chat_id: str,
*,
thread_id: Optional[str] = None,
parent_id: Optional[str] = None,
) -> str:
"""Ephemeral system prompt for this channel/thread.
Uses ``channel_overrides`` when set, else the global gateway prompt.
Legacy ``channel_prompts`` are applied separately via ``event.channel_prompt``
in ``run_sync`` (adapter ``resolve_channel_prompt``), so they are not
duplicated here.
"""
config = getattr(self, "config", None)
if config:
override = _get_channel_override(config, platform, chat_id)
override = _get_channel_override(
config,
platform,
chat_id,
thread_id=thread_id,
parent_id=parent_id,
)
if override and override.system_prompt:
return (override.system_prompt or "").strip()
return getattr(self, "_ephemeral_system_prompt", None) or ""
@ -16872,7 +16954,10 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
if event_channel_prompt:
combined_ephemeral = (combined_ephemeral + "\n\n" + event_channel_prompt).strip()
cfg_channel_prompt = self._get_system_prompt_for_channel(
source.platform, source.chat_id or ""
source.platform,
source.chat_id or "",
thread_id=getattr(source, "thread_id", None),
parent_id=getattr(source, "parent_chat_id", None),
)
if cfg_channel_prompt:
combined_ephemeral = (combined_ephemeral + "\n\n" + cfg_channel_prompt).strip()

View file

@ -1,5 +1,7 @@
"""Tests for per-channel model and system prompt overrides (Fixes #1955)."""
from unittest.mock import patch
import pytest
from gateway.config import (
@ -9,6 +11,7 @@ from gateway.config import (
PlatformConfig,
)
from gateway.run import _get_channel_override, GatewayRunner
from gateway.session import SessionSource
class TestGetChannelOverride:
@ -65,6 +68,60 @@ class TestGetChannelOverride:
)
assert _get_channel_override(config, Platform.DISCORD, "123").model == "gpt-4"
def test_thread_id_lookup_when_chat_id_misses(self):
    """A ``thread_id`` entry is returned when ``chat_id`` has no entry."""
    discord_cfg = PlatformConfig(
        enabled=True,
        channel_overrides={
            "thread_99": ChannelOverride(model="topic-model"),
        },
    )
    config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})

    found = _get_channel_override(
        config,
        Platform.DISCORD,
        "parent_chan",
        thread_id="thread_99",
    )

    assert found is not None
    assert found.model == "topic-model"
def test_parent_id_fallback_when_thread_has_no_entry(self):
    """With no entry for the thread's own id, the ``parent_id`` entry is used."""
    discord_cfg = PlatformConfig(
        enabled=True,
        channel_overrides={
            "parent_chan": ChannelOverride(model="parent-model"),
        },
    )
    config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})

    found = _get_channel_override(
        config, Platform.DISCORD, "thread_only", parent_id="parent_chan"
    )

    assert found is not None
    assert found.model == "parent-model"
def test_exact_thread_overrides_parent(self):
    """An entry for the exact id wins over the entry for ``parent_id``."""
    per_channel = {
        "thread_1": ChannelOverride(model="thread-model"),
        "parent_chan": ChannelOverride(model="parent-model"),
    }
    discord_cfg = PlatformConfig(enabled=True, channel_overrides=per_channel)
    config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})

    found = _get_channel_override(
        config,
        Platform.DISCORD,
        "thread_1",
        parent_id="parent_chan",
    )

    assert found.model == "thread-model"
class TestResolveModelForChannel:
def test_uses_channel_override_when_present(self):
@ -86,7 +143,7 @@ class TestResolveModelForChannel:
def test_falls_back_to_global_when_no_override(self, monkeypatch):
monkeypatch.setattr(
"gateway.run._resolve_gateway_model",
lambda: "global-model/default",
lambda _cfg=None: "global-model/default",
)
config = GatewayConfig(
platforms={
@ -126,3 +183,120 @@ class TestGetSystemPromptForChannel:
runner._ephemeral_system_prompt = "Global prompt"
prompt = runner._get_system_prompt_for_channel(Platform.DISCORD, "other")
assert prompt == "Global prompt"
class TestResolveSessionAgentRuntimePriority:
    """Model/runtime priority: session /model → channel_overrides → global."""

    @staticmethod
    def _bare_runner(channel_overrides, session_overrides):
        """Return a ``GatewayRunner`` created without ``__init__``.

        Only the two attributes these tests set are populated: the Discord
        ``channel_overrides`` config and ``_session_model_overrides``.
        """
        runner = object.__new__(GatewayRunner)
        runner._session_model_overrides = session_overrides
        discord_cfg = PlatformConfig(
            enabled=True,
            channel_overrides=channel_overrides,
        )
        runner.config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})
        return runner

    @staticmethod
    def _runtime(provider, api_key, base_url):
        """Return a fake runtime-kwargs bundle for the patched resolvers."""
        return {
            "provider": provider,
            "api_key": api_key,
            "base_url": base_url,
            "api_mode": "chat_completions",
        }

    def test_channel_override_beats_global(self):
        """A channel entry with model + provider replaces the global pair."""
        runner = self._bare_runner(
            {
                "chan_1": ChannelOverride(
                    model="channel/model",
                    provider="openrouter",
                ),
            },
            {},
        )
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="chan_1",
            user_id="u1",
        )
        global_runtime = self._runtime(
            "anthropic", "k", "https://api.anthropic.com"
        )
        channel_runtime = self._runtime(
            "openrouter", "k2", "https://openrouter.ai/api/v1"
        )

        with patch("gateway.run._resolve_gateway_model", return_value="global/model"):
            with patch(
                "gateway.run._resolve_runtime_agent_kwargs",
                return_value=global_runtime,
            ):
                with patch(
                    "gateway.run._resolve_runtime_agent_kwargs_for_provider",
                    return_value=channel_runtime,
                ):
                    model, runtime = runner._resolve_session_agent_runtime(
                        source=source,
                        user_config={"model": {"default": "global/model"}},
                    )

        assert model == "channel/model"
        assert runtime["provider"] == "openrouter"

    def test_session_model_beats_channel_override(self):
        """A session-scoped override wins over the channel entry."""
        session_key = "agent:main:discord:channel:chan_1"
        runner = self._bare_runner(
            {"chan_1": ChannelOverride(model="channel/model")},
            {
                session_key: {
                    "model": "session/model",
                    "provider": "anthropic",
                },
            },
        )
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="chan_1",
            chat_type="channel",
            user_id="u1",
        )
        global_runtime = self._runtime(
            "openrouter", "k", "https://openrouter.ai/api/v1"
        )

        with patch("gateway.run._resolve_gateway_model", return_value="global/model"):
            with patch(
                "gateway.run._resolve_runtime_agent_kwargs",
                return_value=global_runtime,
            ):
                model, runtime = runner._resolve_session_agent_runtime(
                    source=source,
                    session_key=session_key,
                )

        assert model == "session/model"
        assert runtime["provider"] == "anthropic"

    def test_parent_channel_model_inherited_in_thread(self):
        """A thread with no entry of its own picks up the parent channel's model."""
        runner = self._bare_runner(
            {"parent_chan": ChannelOverride(model="parent/model")},
            {},
        )
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="thread_1",
            chat_type="thread",
            parent_chat_id="parent_chan",
            user_id="u1",
        )
        global_runtime = self._runtime(
            "anthropic", "k", "https://api.anthropic.com"
        )

        with patch("gateway.run._resolve_gateway_model", return_value="global/model"):
            with patch(
                "gateway.run._resolve_runtime_agent_kwargs",
                return_value=global_runtime,
            ):
                model, _runtime = runner._resolve_session_agent_runtime(source=source)

        assert model == "parent/model"

View file

@ -831,6 +831,31 @@ class TestLoadGatewayConfig:
assert config.always_log_local is False
def test_bridges_discord_channel_overrides_from_top_level_yaml(self, tmp_path, monkeypatch):
    """Top-level ``discord.channel_overrides`` YAML reaches the platform config.

    Writes a ``config.yaml`` under a temporary ``HERMES_HOME``, loads the
    gateway config, and checks that the Discord platform exposes the override
    for channel ``"1234567890"`` with its model, provider and system prompt.
    """
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    config_path = hermes_home / "config.yaml"
    # YAML nesting is defined by indentation: each level below ``discord:``
    # is indented two further spaces so the override fields sit under the
    # channel id, and the channel id sits under ``channel_overrides``. With a
    # uniform one-space indent all of these keys would parse as siblings.
    config_path.write_text(
        "discord:\n"
        "  channel_overrides:\n"
        '    "1234567890":\n'
        "      model: openrouter/healer-alpha\n"
        "      provider: openrouter\n"
        "      system_prompt: Daily news summarizer\n",
        encoding="utf-8",
    )
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    config = load_gateway_config()

    discord = config.platforms[Platform.DISCORD]
    assert "1234567890" in discord.channel_overrides
    ov = discord.channel_overrides["1234567890"]
    assert ov.model == "openrouter/healer-alpha"
    assert ov.provider == "openrouter"
    assert ov.system_prompt == "Daily news summarizer"
def test_bridges_discord_channel_prompts_from_config_yaml(self, tmp_path, monkeypatch):
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()