From 0010c14e66cffbed84dfbf1b4a15093f0f8cc76d Mon Sep 17 00:00:00 2001 From: crazywriter1 Date: Sun, 17 May 2026 16:31:02 +0300 Subject: [PATCH] feat(gateway): per-channel model and system prompt overrides (Fixes #1955) - ChannelOverride + channel_overrides on PlatformConfig - Resolve model/runtime: session /model, then channel_overrides, then global - Thread/parent channel lookup; bridge discord.channel_overrides from YAML - Drop unrelated test and delegate_tool changes from PR scope --- gateway/config.py | 20 ++- gateway/run.py | 155 ++++++++++++++++----- tests/gateway/test_channel_overrides.py | 176 +++++++++++++++++++++++- tests/gateway/test_config.py | 25 ++++ 4 files changed, 336 insertions(+), 40 deletions(-) diff --git a/gateway/config.py b/gateway/config.py index 47d0b44b8..7693cef1c 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -379,7 +379,7 @@ class PlatformConfig: # noise; keep True for back-channels where the operator wants them. gateway_restart_notification: bool = True -# Whether the gateway shows a "typing…" / "is thinking…" status indicator + # Whether the gateway shows a "typing…" / "is thinking…" status indicator # while the agent processes a message on this platform. Default True # preserves prior behavior. Set False on platforms where the indicator is # unwanted (e.g. Slack's assistant.threads.setStatus "is thinking…", which @@ -420,7 +420,7 @@ class PlatformConfig: if "home_channel" in data: home_channel = HomeChannel.from_dict(data["home_channel"]) -# gateway_restart_notification may be bridged into extra via the + # gateway_restart_notification may be bridged into extra via the # shared-key loop in load_gateway_config(); check both top-level # and extra so YAML ``discord: gateway_restart_notification: false`` # works without needing a separate platforms: block. @@ -448,7 +448,7 @@ class PlatformConfig: api_key=data.get("api_key"), home_channel=home_channel, reply_to_mode=data.get("reply_to_mode", "first"), -gateway_restart_notification=_coerce_bool(_grn, True), + gateway_restart_notification=_coerce_bool(_grn, True), typing_indicator=_coerce_bool(_typing, True), channel_overrides=channel_overrides, extra=data.get("extra", {}), @@ -1103,8 +1103,20 @@ def load_gateway_config() -> GatewayConfig: bridged["gateway_restart_notification"] = platform_cfg["gateway_restart_notification"] if "typing_indicator" in platform_cfg: bridged["typing_indicator"] = platform_cfg["typing_indicator"] + has_channel_overrides = "channel_overrides" in platform_cfg + if has_channel_overrides: + raw_overrides = platform_cfg.get("channel_overrides") + if isinstance(raw_overrides, dict): + plat_data, _extra = _ensure_platform_extra_dict( + platforms_data, plat.value + ) + plat_data["channel_overrides"] = { + str(cid): ov_data + for cid, ov_data in raw_overrides.items() + if isinstance(ov_data, dict) + } enabled_was_explicit = _cfg_toplevel and "enabled" in platform_cfg - if not bridged and not enabled_was_explicit: + if not bridged and not enabled_was_explicit and not has_channel_overrides: continue plat_data, extra = _ensure_platform_extra_dict(platforms_data, plat.value) if enabled_was_explicit: diff --git a/gateway/run.py b/gateway/run.py index cc62263ab..4cdf1f0b9 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2306,18 +2306,54 @@ def _resolve_gateway_model(config: dict | None = None) -> str: return "" +def _channel_override_lookup_keys( + chat_id: str, + *, + thread_id: Optional[str] = None, + parent_id: Optional[str] = None, +) -> list[str]: + """Ordered, de-duplicated keys for ``channel_overrides`` lookup. + + Matches ``resolve_channel_prompt`` semantics: exact thread/channel id first, + then parent channel/forum id (Discord threads inherit parent overrides). + """ + keys: list[str] = [] + seen: set[str] = set() + for key in (chat_id, thread_id, parent_id): + if not key: + continue + sk = str(key) + if sk in seen: + continue + seen.add(sk) + keys.append(sk) + return keys + + def _get_channel_override( config: GatewayConfig, platform: Platform, chat_id: str, + *, + thread_id: Optional[str] = None, + parent_id: Optional[str] = None, ) -> Optional[ChannelOverride]: - """Return per-channel override for this platform/chat_id, or None.""" - if not chat_id: - return None + """Return per-channel override for this platform/chat_id, or None. + + Looks up ``channel_overrides`` by ``chat_id``, then ``thread_id``, then + ``parent_id`` (forum threads / child channels inherit the parent entry). + """ platform_config = config.platforms.get(platform) if not platform_config or not platform_config.channel_overrides: return None - return platform_config.channel_overrides.get(str(chat_id)) + overrides = platform_config.channel_overrides + for key in _channel_override_lookup_keys( + chat_id, thread_id=thread_id, parent_id=parent_id + ): + ov = overrides.get(key) + if ov is not None: + return ov + return None def _resolve_hermes_bin() -> Optional[list[str]]: @@ -3579,11 +3615,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew session_key: Optional[str] = None, user_config: Optional[dict] = None, ) -> tuple[str, dict]: - """Resolve model/runtime for a session, honoring session-scoped /model overrides. + """Resolve model/runtime for a session. - If the session override already contains a complete provider bundle - (provider/api_key/base_url/api_mode), prefer it directly instead of - resolving fresh global runtime state first. + Priority (highest first): session ``/model`` → ``channel_overrides`` → + global config/env (``_resolve_gateway_model(user_config)`` and default + provider resolution). """ resolved_session_key = session_key if not resolved_session_key and source is not None: @@ -3632,30 +3668,43 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew runtime_model, ) model = runtime_model + + cfg = getattr(self, "config", None) + if cfg and source is not None: + chat_id = str(source.chat_id) if source.chat_id else "" + thread_id = ( + str(source.thread_id) if getattr(source, "thread_id", None) else None + ) + parent_id = ( + str(source.parent_chat_id) + if getattr(source, "parent_chat_id", None) + else None + ) + ch = _get_channel_override( + cfg, + source.platform, + chat_id, + thread_id=thread_id, + parent_id=parent_id, + ) + if ch: + if ch.model: + model = ch.model + if ch.provider: + runtime_kwargs = _resolve_runtime_agent_kwargs_for_provider( + ch.provider + ) + ch_runtime_model = runtime_kwargs.pop("model", None) + # Only adopt the provider's bundled model when the override + # did not specify an explicit model. + if ch_runtime_model and not ch.model: + model = ch_runtime_model + if override and resolved_session_key: model, runtime_kwargs = self._apply_session_model_override( resolved_session_key, model, runtime_kwargs ) - cfg = getattr(self, "config", None) - if cfg and source is not None and source.chat_id: - ch = _get_channel_override(cfg, source.platform, str(source.chat_id)) - if ch: - channel_touch = False - if ch.model and not (override and override.get("model")): - model = ch.model - channel_touch = True - if ch.provider and not (override and override.get("provider")): - runtime_kwargs = _resolve_runtime_agent_kwargs_for_provider(ch.provider) - runtime_model = runtime_kwargs.pop("model", None) - if runtime_model: - model = runtime_model or model - channel_touch = True - if channel_touch and override and resolved_session_key: - model, runtime_kwargs = self._apply_session_model_override( - resolved_session_key, model, runtime_kwargs - ) - # When the config has no model.default but a provider was resolved # (e.g. user ran `hermes auth add openai-codex` without `hermes model`), # fall back to the provider's first catalog model so the API call @@ -4528,20 +4577,53 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew cfg = _load_gateway_runtime_config() return str(cfg_get(cfg, "agent", "system_prompt", default="") or "").strip() - def _resolve_model_for_channel(self, platform: Platform, chat_id: str) -> str: - """Resolve model for this channel: channel_overrides[channel_id] else global default.""" + def _resolve_model_for_channel( + self, + platform: Platform, + chat_id: str, + *, + user_config: Optional[dict] = None, + thread_id: Optional[str] = None, + parent_id: Optional[str] = None, + ) -> str: + """Resolve model for this channel: channel_overrides else global default.""" config = getattr(self, "config", None) if config: - override = _get_channel_override(config, platform, chat_id) + override = _get_channel_override( + config, + platform, + chat_id, + thread_id=thread_id, + parent_id=parent_id, + ) if override and override.model: return override.model - return _resolve_gateway_model() + return _resolve_gateway_model(user_config) - def _get_system_prompt_for_channel(self, platform: Platform, chat_id: str) -> str: - """System prompt for this channel: channel override else global ephemeral.""" + def _get_system_prompt_for_channel( + self, + platform: Platform, + chat_id: str, + *, + thread_id: Optional[str] = None, + parent_id: Optional[str] = None, + ) -> str: + """Ephemeral system prompt for this channel/thread. + + Uses ``channel_overrides`` when set, else the global gateway prompt. + Legacy ``channel_prompts`` are applied separately via ``event.channel_prompt`` + in ``run_sync`` (adapter ``resolve_channel_prompt``), so they are not + duplicated here. + """ config = getattr(self, "config", None) if config: - override = _get_channel_override(config, platform, chat_id) + override = _get_channel_override( + config, + platform, + chat_id, + thread_id=thread_id, + parent_id=parent_id, + ) if override and override.system_prompt: return (override.system_prompt or "").strip() return getattr(self, "_ephemeral_system_prompt", None) or "" @@ -16872,7 +16954,10 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew if event_channel_prompt: combined_ephemeral = (combined_ephemeral + "\n\n" + event_channel_prompt).strip() cfg_channel_prompt = self._get_system_prompt_for_channel( - source.platform, source.chat_id or "" + source.platform, + source.chat_id or "", + thread_id=getattr(source, "thread_id", None), + parent_id=getattr(source, "parent_chat_id", None), ) if cfg_channel_prompt: combined_ephemeral = (combined_ephemeral + "\n\n" + cfg_channel_prompt).strip() diff --git a/tests/gateway/test_channel_overrides.py b/tests/gateway/test_channel_overrides.py index 73046a99f..9ad288705 100644 --- a/tests/gateway/test_channel_overrides.py +++ b/tests/gateway/test_channel_overrides.py @@ -1,5 +1,7 @@ """Tests for per-channel model and system prompt overrides (Fixes #1955).""" +from unittest.mock import patch + import pytest from gateway.config import ( @@ -9,6 +11,7 @@ from gateway.config import ( PlatformConfig, ) from gateway.run import _get_channel_override, GatewayRunner +from gateway.session import SessionSource class TestGetChannelOverride: @@ -65,6 +68,60 @@ class TestGetChannelOverride: ) assert _get_channel_override(config, Platform.DISCORD, "123").model == "gpt-4" + def test_thread_id_lookup_when_chat_id_misses(self): + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={ + "thread_99": ChannelOverride(model="topic-model"), + }, + ), + }, + ) + result = _get_channel_override( + config, Platform.DISCORD, "parent_chan", thread_id="thread_99" + ) + assert result is not None + assert result.model == "topic-model" + + def test_parent_id_fallback_when_thread_has_no_entry(self): + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={ + "parent_chan": ChannelOverride(model="parent-model"), + }, + ), + }, + ) + result = _get_channel_override( + config, + Platform.DISCORD, + "thread_only", + parent_id="parent_chan", + ) + assert result is not None + assert result.model == "parent-model" + + def test_exact_thread_overrides_parent(self): + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={ + "thread_1": ChannelOverride(model="thread-model"), + "parent_chan": ChannelOverride(model="parent-model"), + }, + ), + }, + ) + result = _get_channel_override( + config, Platform.DISCORD, "thread_1", parent_id="parent_chan" + ) + assert result.model == "thread-model" + class TestResolveModelForChannel: def test_uses_channel_override_when_present(self): @@ -86,7 +143,7 @@ class TestResolveModelForChannel: def test_falls_back_to_global_when_no_override(self, monkeypatch): monkeypatch.setattr( "gateway.run._resolve_gateway_model", - lambda: "global-model/default", + lambda _cfg=None: "global-model/default", ) config = GatewayConfig( platforms={ @@ -126,3 +183,120 @@ class TestGetSystemPromptForChannel: runner._ephemeral_system_prompt = "Global prompt" prompt = runner._get_system_prompt_for_channel(Platform.DISCORD, "other") assert prompt == "Global prompt" + + +class TestResolveSessionAgentRuntimePriority: + """Model/runtime priority: session /model → channel_overrides → global.""" + + def test_channel_override_beats_global(self): + runner = object.__new__(GatewayRunner) + runner._session_model_overrides = {} + runner.config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={ + "chan_1": ChannelOverride( + model="channel/model", + provider="openrouter", + ), + }, + ), + }, + ) + source = SessionSource( + platform=Platform.DISCORD, + chat_id="chan_1", + user_id="u1", + ) + with patch("gateway.run._resolve_gateway_model", return_value="global/model"), \ + patch("gateway.run._resolve_runtime_agent_kwargs", return_value={ + "provider": "anthropic", + "api_key": "k", + "base_url": "https://api.anthropic.com", + "api_mode": "chat_completions", + }), \ + patch( + "gateway.run._resolve_runtime_agent_kwargs_for_provider", + return_value={ + "provider": "openrouter", + "api_key": "k2", + "base_url": "https://openrouter.ai/api/v1", + "api_mode": "chat_completions", + }, + ): + model, runtime = runner._resolve_session_agent_runtime( + source=source, + user_config={"model": {"default": "global/model"}}, + ) + assert model == "channel/model" + assert runtime["provider"] == "openrouter" + + def test_session_model_beats_channel_override(self): + runner = object.__new__(GatewayRunner) + runner.config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={ + "chan_1": ChannelOverride(model="channel/model"), + }, + ), + }, + ) + session_key = "agent:main:discord:channel:chan_1" + runner._session_model_overrides = { + session_key: { + "model": "session/model", + "provider": "anthropic", + }, + } + source = SessionSource( + platform=Platform.DISCORD, + chat_id="chan_1", + chat_type="channel", + user_id="u1", + ) + with patch("gateway.run._resolve_gateway_model", return_value="global/model"), \ + patch("gateway.run._resolve_runtime_agent_kwargs", return_value={ + "provider": "openrouter", + "api_key": "k", + "base_url": "https://openrouter.ai/api/v1", + "api_mode": "chat_completions", + }): + model, runtime = runner._resolve_session_agent_runtime( + source=source, + session_key=session_key, + ) + assert model == "session/model" + assert runtime["provider"] == "anthropic" + + def test_parent_channel_model_inherited_in_thread(self): + runner = object.__new__(GatewayRunner) + runner._session_model_overrides = {} + runner.config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={ + "parent_chan": ChannelOverride(model="parent/model"), + }, + ), + }, + ) + source = SessionSource( + platform=Platform.DISCORD, + chat_id="thread_1", + chat_type="thread", + parent_chat_id="parent_chan", + user_id="u1", + ) + with patch("gateway.run._resolve_gateway_model", return_value="global/model"), \ + patch("gateway.run._resolve_runtime_agent_kwargs", return_value={ + "provider": "anthropic", + "api_key": "k", + "base_url": "https://api.anthropic.com", + "api_mode": "chat_completions", + }): + model, _runtime = runner._resolve_session_agent_runtime(source=source) + assert model == "parent/model" diff --git a/tests/gateway/test_config.py b/tests/gateway/test_config.py index 43df4a495..3f787403b 100644 --- a/tests/gateway/test_config.py +++ b/tests/gateway/test_config.py @@ -831,6 +831,31 @@ class TestLoadGatewayConfig: assert config.always_log_local is False + def test_bridges_discord_channel_overrides_from_top_level_yaml(self, tmp_path, monkeypatch): + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + config_path = hermes_home / "config.yaml" + config_path.write_text( + "discord:\n" + " channel_overrides:\n" + ' "1234567890":\n' + " model: openrouter/healer-alpha\n" + " provider: openrouter\n" + " system_prompt: Daily news summarizer\n", + encoding="utf-8", + ) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + config = load_gateway_config() + + discord = config.platforms[Platform.DISCORD] + assert "1234567890" in discord.channel_overrides + ov = discord.channel_overrides["1234567890"] + assert ov.model == "openrouter/healer-alpha" + assert ov.provider == "openrouter" + assert ov.system_prompt == "Daily news summarizer" + def test_bridges_discord_channel_prompts_from_config_yaml(self, tmp_path, monkeypatch): hermes_home = tmp_path / ".hermes" hermes_home.mkdir()