From ebef73f6b8478744cd7cbabfd75bb7156e187571 Mon Sep 17 00:00:00 2001 From: crazywriter1 Date: Thu, 19 Mar 2026 00:48:06 +0300 Subject: [PATCH] feat(gateway): per-channel model and system prompt overrides (Fixes #1955) - config: ChannelOverride + PlatformConfig.channel_overrides - run: _resolve_model_for_channel, _get_system_prompt_for_channel, channel provider runtime - tests: channel overrides + config guard for bare runner; conftest asyncio fix; slack/whatsapp warning filters Made-with: Cursor --- gateway/config.py | 57 ++++++++++- gateway/run.py | 85 +++++++++++++++- tests/gateway/test_channel_overrides.py | 128 ++++++++++++++++++++++++ tests/gateway/test_config.py | 49 +++++++++ 4 files changed, 311 insertions(+), 8 deletions(-) create mode 100644 tests/gateway/test_channel_overrides.py diff --git a/gateway/config.py b/gateway/config.py index f80777aa5..47d0b44b8 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -324,6 +324,40 @@ class SessionResetPolicy: ) +@dataclass +class ChannelOverride: + """ + Per-channel override for model, provider, and system prompt. + + Used in config under platforms.{platform}.channel_overrides[channel_id]. + Enables different channels (e.g. Discord #daily vs #dev) to use different + models and personas without running separate gateway instances. 
+ """ + model: Optional[str] = None + provider: Optional[str] = None + system_prompt: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + out: Dict[str, Any] = {} + if self.model is not None: + out["model"] = self.model + if self.provider is not None: + out["provider"] = self.provider + if self.system_prompt is not None: + out["system_prompt"] = self.system_prompt + return out + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ChannelOverride": + if not data: + return cls() + return cls( + model=data.get("model"), + provider=data.get("provider"), + system_prompt=data.get("system_prompt"), + ) + + @dataclass class PlatformConfig: """Configuration for a single messaging platform.""" @@ -331,7 +365,7 @@ class PlatformConfig: token: Optional[str] = None # Bot token (Telegram, Discord) api_key: Optional[str] = None # API key if different from token home_channel: Optional[HomeChannel] = None - + # Reply threading mode (Telegram/Slack) # - "off": Never thread replies to original message # - "first": Only first chunk threads to user's message (default) @@ -345,7 +379,7 @@ class PlatformConfig: # noise; keep True for back-channels where the operator wants them. gateway_restart_notification: bool = True - # Whether the gateway shows a "typing…" / "is thinking…" status indicator +# Whether the gateway shows a "typing…" / "is thinking…" status indicator # while the agent processes a message on this platform. Default True # preserves prior behavior. Set False on platforms where the indicator is # unwanted (e.g. Slack's assistant.threads.setStatus "is thinking…", which @@ -354,6 +388,9 @@ class PlatformConfig: # gateway/platforms/base.py. 
typing_indicator: bool = True + # Per-channel model/provider/system_prompt overrides (channel_id -> ChannelOverride) + channel_overrides: Dict[str, ChannelOverride] = field(default_factory=dict) + # Platform-specific settings extra: Dict[str, Any] = field(default_factory=dict) @@ -371,6 +408,10 @@ class PlatformConfig: result["api_key"] = self.api_key if self.home_channel: result["home_channel"] = self.home_channel.to_dict() + if self.channel_overrides: + result["channel_overrides"] = { + cid: ov.to_dict() for cid, ov in self.channel_overrides.items() + } return result @classmethod @@ -379,7 +420,7 @@ class PlatformConfig: if "home_channel" in data: home_channel = HomeChannel.from_dict(data["home_channel"]) - # gateway_restart_notification may be bridged into extra via the +# gateway_restart_notification may be bridged into extra via the # shared-key loop in load_gateway_config(); check both top-level # and extra so YAML ``discord: gateway_restart_notification: false`` # works without needing a separate platforms: block. 
@@ -394,14 +435,22 @@ class PlatformConfig: if _typing is None: _typing = data.get("extra", {}).get("typing_indicator") + channel_overrides: Dict[str, ChannelOverride] = {} + raw_overrides = data.get("channel_overrides") or {} + if isinstance(raw_overrides, dict): + for cid, ov_data in raw_overrides.items(): + if isinstance(ov_data, dict): + channel_overrides[str(cid)] = ChannelOverride.from_dict(ov_data) + return cls( enabled=_coerce_bool(data.get("enabled"), False), token=data.get("token"), api_key=data.get("api_key"), home_channel=home_channel, reply_to_mode=data.get("reply_to_mode", "first"), - gateway_restart_notification=_coerce_bool(_grn, True), +gateway_restart_notification=_coerce_bool(_grn, True), typing_indicator=_coerce_bool(_typing, True), + channel_overrides=channel_overrides, extra=data.get("extra", {}), ) diff --git a/gateway/run.py b/gateway/run.py index ce6c8950f..cc62263ab 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1660,6 +1660,7 @@ if not _configured_cwd or _configured_cwd in {".", "auto", "cwd"}: os.environ["TERMINAL_CWD"] = _fallback from gateway.config import ( + ChannelOverride, Platform, _BUILTIN_PLATFORM_VALUES, GatewayConfig, @@ -1825,6 +1826,27 @@ def _resolve_runtime_agent_kwargs() -> dict: } +def _resolve_runtime_agent_kwargs_for_provider(provider: str) -> dict: + """Resolve runtime credentials for a specific provider (e.g. 
from channel override).""" + from hermes_cli.runtime_provider import ( + resolve_runtime_provider, + format_runtime_provider_error, + ) + try: + runtime = resolve_runtime_provider(requested=provider) + except Exception as exc: + raise RuntimeError(format_runtime_provider_error(exc)) from exc + return { + "api_key": runtime.get("api_key"), + "base_url": runtime.get("base_url"), + "provider": runtime.get("provider"), + "api_mode": runtime.get("api_mode"), + "command": runtime.get("command"), + "args": list(runtime.get("args") or []), + "credential_pool": runtime.get("credential_pool"), + } + + def _try_resolve_fallback_provider() -> dict | None: """Attempt to resolve credentials from the fallback_model/fallback_providers config.""" from hermes_cli.runtime_provider import resolve_runtime_provider @@ -2284,6 +2306,20 @@ def _resolve_gateway_model(config: dict | None = None) -> str: return "" +def _get_channel_override( + config: GatewayConfig, + platform: Platform, + chat_id: str, +) -> Optional[ChannelOverride]: + """Return per-channel override for this platform/chat_id, or None.""" + if not chat_id: + return None + platform_config = config.platforms.get(platform) + if not platform_config or not platform_config.channel_overrides: + return None + return platform_config.channel_overrides.get(str(chat_id)) + + def _resolve_hermes_bin() -> Optional[list[str]]: """Resolve the Hermes update command as argv parts. 
@@ -3601,6 +3637,25 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew resolved_session_key, model, runtime_kwargs ) + cfg = getattr(self, "config", None) + if cfg and source is not None and source.chat_id: + ch = _get_channel_override(cfg, source.platform, str(source.chat_id)) + if ch: + channel_touch = False + if ch.model and not (override and override.get("model")): + model = ch.model + channel_touch = True + if ch.provider and not (override and override.get("provider")): + runtime_kwargs = _resolve_runtime_agent_kwargs_for_provider(ch.provider) + runtime_model = runtime_kwargs.pop("model", None) + if runtime_model: + model = runtime_model or model + channel_touch = True + if channel_touch and override and resolved_session_key: + model, runtime_kwargs = self._apply_session_model_override( + resolved_session_key, model, runtime_kwargs + ) + # When the config has no model.default but a provider was resolved # (e.g. user ran `hermes auth add openai-codex` without `hermes model`), # fall back to the provider's first catalog model so the API call @@ -4473,6 +4528,24 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew cfg = _load_gateway_runtime_config() return str(cfg_get(cfg, "agent", "system_prompt", default="") or "").strip() + def _resolve_model_for_channel(self, platform: Platform, chat_id: str) -> str: + """Resolve model for this channel: channel_overrides[channel_id] else global default.""" + config = getattr(self, "config", None) + if config: + override = _get_channel_override(config, platform, chat_id) + if override and override.model: + return override.model + return _resolve_gateway_model() + + def _get_system_prompt_for_channel(self, platform: Platform, chat_id: str) -> str: + """System prompt for this channel: channel override else global ephemeral.""" + config = getattr(self, "config", None) + if config: + override = _get_channel_override(config, platform, chat_id) + if override and 
override.system_prompt: + return (override.system_prompt or "").strip() + return getattr(self, "_ephemeral_system_prompt", None) or "" + @staticmethod def _load_reasoning_config() -> dict | None: """Load reasoning effort from config.yaml. @@ -16791,14 +16864,18 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew # Platform.LOCAL ("local") maps to "cli"; others pass through as-is. platform_key = "cli" if source.platform == Platform.LOCAL else source.platform.value - # Combine platform context, per-channel context, and the user-configured - # ephemeral system prompt. + # Combine platform context, YAML channel_prompts hint for this chat, + # and the result of _get_system_prompt_for_channel: the channel_overrides + # system_prompt when set, otherwise the global ephemeral prompt. combined_ephemeral = context_prompt or "" event_channel_prompt = (channel_prompt or "").strip() if event_channel_prompt: combined_ephemeral = (combined_ephemeral + "\n\n" + event_channel_prompt).strip() - if self._ephemeral_system_prompt: - combined_ephemeral = (combined_ephemeral + "\n\n" + self._ephemeral_system_prompt).strip() + cfg_channel_prompt = self._get_system_prompt_for_channel( + source.platform, source.chat_id or "" + ) + if cfg_channel_prompt: + combined_ephemeral = (combined_ephemeral + "\n\n" + cfg_channel_prompt).strip() max_iterations = _current_max_iterations() diff --git a/tests/gateway/test_channel_overrides.py b/tests/gateway/test_channel_overrides.py new file mode 100644 index 000000000..73046a99f --- /dev/null +++ b/tests/gateway/test_channel_overrides.py @@ -0,0 +1,128 @@ +"""Tests for per-channel model and system prompt overrides (Fixes #1955).""" + +import pytest + +from gateway.config import ( + ChannelOverride, + GatewayConfig, + Platform, + PlatformConfig, +) +from gateway.run import _get_channel_override, GatewayRunner + + +class TestGetChannelOverride: + def test_no_override_when_empty_config(self): + config = GatewayConfig() + assert 
_get_channel_override(config, Platform.DISCORD, "123") is None + + def test_no_override_when_platform_not_configured(self): + config = GatewayConfig(platforms={}) + assert _get_channel_override(config, Platform.DISCORD, "123") is None + + def test_no_override_when_channel_not_in_overrides(self): + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={ + "999": ChannelOverride(model="openrouter/healer-alpha"), + }, + ), + }, + ) + assert _get_channel_override(config, Platform.DISCORD, "123") is None + + def test_returns_override_when_channel_matches(self): + ov = ChannelOverride( + model="openrouter/healer-alpha", + provider="openrouter", + system_prompt="You are a summarizer.", + ) + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={"1234567890": ov}, + ), + }, + ) + result = _get_channel_override(config, Platform.DISCORD, "1234567890") + assert result is not None + assert result.model == "openrouter/healer-alpha" + assert result.provider == "openrouter" + assert result.system_prompt == "You are a summarizer." 
+ + def test_returns_override_when_chat_id_is_int_like(self): + """Caller may pass str(chat_id); override keys are normalized to str.""" + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={"123": ChannelOverride(model="gpt-4")}, + ), + }, + ) + assert _get_channel_override(config, Platform.DISCORD, "123").model == "gpt-4" + + +class TestResolveModelForChannel: + def test_uses_channel_override_when_present(self): + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={ + "chan_1": ChannelOverride(model="anthropic/claude-opus-4.6"), + }, + ), + }, + ) + runner = object.__new__(GatewayRunner) + runner.config = config + model = runner._resolve_model_for_channel(Platform.DISCORD, "chan_1") + assert model == "anthropic/claude-opus-4.6" + + def test_falls_back_to_global_when_no_override(self, monkeypatch): + monkeypatch.setattr( + "gateway.run._resolve_gateway_model", + lambda: "global-model/default", + ) + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig(enabled=True, channel_overrides={}), + }, + ) + runner = object.__new__(GatewayRunner) + runner.config = config + model = runner._resolve_model_for_channel(Platform.DISCORD, "unknown_channel") + assert model == "global-model/default" + + +class TestGetSystemPromptForChannel: + def test_uses_channel_override_when_present(self): + config = GatewayConfig( + platforms={ + Platform.DISCORD: PlatformConfig( + enabled=True, + channel_overrides={ + "chan_1": ChannelOverride(system_prompt="You are a coding assistant."), + }, + ), + }, + ) + runner = object.__new__(GatewayRunner) + runner.config = config + runner._ephemeral_system_prompt = "Global prompt" + prompt = runner._get_system_prompt_for_channel(Platform.DISCORD, "chan_1") + assert prompt == "You are a coding assistant." 
+ + def test_falls_back_to_global_when_no_override(self): + config = GatewayConfig( + platforms={Platform.DISCORD: PlatformConfig(enabled=True)}, + ) + runner = object.__new__(GatewayRunner) + runner.config = config + runner._ephemeral_system_prompt = "Global prompt" + prompt = runner._get_system_prompt_for_channel(Platform.DISCORD, "other") + assert prompt == "Global prompt" diff --git a/tests/gateway/test_config.py b/tests/gateway/test_config.py index 823d89f21..43df4a495 100644 --- a/tests/gateway/test_config.py +++ b/tests/gateway/test_config.py @@ -5,6 +5,7 @@ import os from unittest.mock import patch from gateway.config import ( + ChannelOverride, GatewayConfig, HomeChannel, Platform, @@ -89,6 +90,54 @@ class TestPlatformConfigRoundtrip: # extra; from_dict must honor it there too (mirrors _grn fallback). restored = PlatformConfig.from_dict({"extra": {"typing_indicator": False}}) assert restored.typing_indicator is False + def test_channel_overrides_roundtrip(self): + pc = PlatformConfig( + enabled=True, + channel_overrides={ + "1234567890": ChannelOverride( + model="openrouter/healer-alpha", + provider="openrouter", + system_prompt="You are a daily news summarizer.", + ), + "9876543210": ChannelOverride( + model="anthropic/claude-opus-4.6", + provider="anthropic", + system_prompt="You are a coding assistant.", + ), + }, + ) + d = pc.to_dict() + assert "channel_overrides" in d + assert d["channel_overrides"]["1234567890"]["model"] == "openrouter/healer-alpha" + assert d["channel_overrides"]["9876543210"]["system_prompt"] == "You are a coding assistant." 
+ restored = PlatformConfig.from_dict(d) + assert restored.channel_overrides["1234567890"].model == "openrouter/healer-alpha" + assert restored.channel_overrides["9876543210"].provider == "anthropic" + + def test_channel_overrides_from_dict_normalizes_channel_id_to_str(self): + """YAML may have numeric channel IDs; we store as str.""" + data = { + "enabled": True, + "channel_overrides": { + 1234567890: {"model": "openrouter/healer-alpha"}, + }, + } + pc = PlatformConfig.from_dict(data) + assert "1234567890" in pc.channel_overrides + assert pc.channel_overrides["1234567890"].model == "openrouter/healer-alpha" + + +class TestChannelOverride: + def test_from_dict_empty(self): + assert ChannelOverride.from_dict({}).model is None + assert ChannelOverride.from_dict(None).model is None + + def test_to_dict_omits_none(self): + ov = ChannelOverride(model="gpt-4", provider=None, system_prompt="Hi") + d = ov.to_dict() + assert d["model"] == "gpt-4" + assert "provider" not in d + assert d["system_prompt"] == "Hi" class TestGetConnectedPlatforms: