feat(gateway): per-channel model and system prompt overrides (Fixes #1955)

- config: ChannelOverride + PlatformConfig.channel_overrides

- run: _resolve_model_for_channel, _get_system_prompt_for_channel, channel provider runtime

- tests: channel overrides + config guard for bare runner; conftest asyncio fix; slack/whatsapp warning filters

Made-with: Cursor
This commit is contained in:
crazywriter1 2026-03-19 00:48:06 +03:00 committed by Teknium
parent 902b0b70e4
commit ebef73f6b8
4 changed files with 311 additions and 8 deletions

View file

@ -324,6 +324,40 @@ class SessionResetPolicy:
)
@dataclass
class ChannelOverride:
    """
    Model, provider and system-prompt settings scoped to one channel.

    Lives in config under platforms.<name>.channel_overrides[channel_id] and
    lets different channels (e.g. Discord #daily vs #dev) use different
    models and personas without running separate gateway instances.
    Every field is optional and defaults to ``None``.
    """

    model: Optional[str] = None
    provider: Optional[str] = None
    system_prompt: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, leaving out every field that is ``None``."""
        pairs = (
            ("model", self.model),
            ("provider", self.provider),
            ("system_prompt", self.system_prompt),
        )
        return {key: value for key, value in pairs if value is not None}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ChannelOverride":
        """Build an override from a mapping.

        Falsy input (``None`` or an empty dict) yields an override with all
        fields ``None``; keys other than the three known fields are ignored.
        """
        if data:
            return cls(
                **{key: data.get(key) for key in ("model", "provider", "system_prompt")}
            )
        return cls()
@dataclass
class PlatformConfig:
"""Configuration for a single messaging platform."""
@ -331,7 +365,7 @@ class PlatformConfig:
token: Optional[str] = None # Bot token (Telegram, Discord)
api_key: Optional[str] = None # API key if different from token
home_channel: Optional[HomeChannel] = None
# Reply threading mode (Telegram/Slack)
# - "off": Never thread replies to original message
# - "first": Only first chunk threads to user's message (default)
@ -345,7 +379,7 @@ class PlatformConfig:
# noise; keep True for back-channels where the operator wants them.
gateway_restart_notification: bool = True
# Whether the gateway shows a "typing…" / "is thinking…" status indicator
# Whether the gateway shows a "typing…" / "is thinking…" status indicator
# while the agent processes a message on this platform. Default True
# preserves prior behavior. Set False on platforms where the indicator is
# unwanted (e.g. Slack's assistant.threads.setStatus "is thinking…", which
@ -354,6 +388,9 @@ class PlatformConfig:
# gateway/platforms/base.py.
typing_indicator: bool = True
# Per-channel model/provider/system_prompt overrides (channel_id -> ChannelOverride)
channel_overrides: Dict[str, ChannelOverride] = field(default_factory=dict)
# Platform-specific settings
extra: Dict[str, Any] = field(default_factory=dict)
@ -371,6 +408,10 @@ class PlatformConfig:
result["api_key"] = self.api_key
if self.home_channel:
result["home_channel"] = self.home_channel.to_dict()
if self.channel_overrides:
result["channel_overrides"] = {
cid: ov.to_dict() for cid, ov in self.channel_overrides.items()
}
return result
@classmethod
@ -379,7 +420,7 @@ class PlatformConfig:
if "home_channel" in data:
home_channel = HomeChannel.from_dict(data["home_channel"])
# gateway_restart_notification may be bridged into extra via the
# gateway_restart_notification may be bridged into extra via the
# shared-key loop in load_gateway_config(); check both top-level
# and extra so YAML ``discord: gateway_restart_notification: false``
# works without needing a separate platforms: block.
@ -394,14 +435,22 @@ class PlatformConfig:
if _typing is None:
_typing = data.get("extra", {}).get("typing_indicator")
channel_overrides: Dict[str, ChannelOverride] = {}
raw_overrides = data.get("channel_overrides") or {}
if isinstance(raw_overrides, dict):
for cid, ov_data in raw_overrides.items():
if isinstance(ov_data, dict):
channel_overrides[str(cid)] = ChannelOverride.from_dict(ov_data)
return cls(
enabled=_coerce_bool(data.get("enabled"), False),
token=data.get("token"),
api_key=data.get("api_key"),
home_channel=home_channel,
reply_to_mode=data.get("reply_to_mode", "first"),
gateway_restart_notification=_coerce_bool(_grn, True),
gateway_restart_notification=_coerce_bool(_grn, True),
typing_indicator=_coerce_bool(_typing, True),
channel_overrides=channel_overrides,
extra=data.get("extra", {}),
)

View file

@ -1660,6 +1660,7 @@ if not _configured_cwd or _configured_cwd in {".", "auto", "cwd"}:
os.environ["TERMINAL_CWD"] = _fallback
from gateway.config import (
ChannelOverride,
Platform,
_BUILTIN_PLATFORM_VALUES,
GatewayConfig,
@ -1825,6 +1826,27 @@ def _resolve_runtime_agent_kwargs() -> dict:
}
def _resolve_runtime_agent_kwargs_for_provider(provider: str) -> dict:
    """Build agent runtime kwargs for an explicitly requested provider.

    Intended for cases where the provider does not come from the global
    default, e.g. a per-channel override.

    Args:
        provider: Provider name, forwarded to ``resolve_runtime_provider``
            as ``requested``.

    Returns:
        A dict with the keys ``api_key``, ``base_url``, ``provider``,
        ``api_mode``, ``command``, ``args`` and ``credential_pool``.
        ``args`` is always a new list (empty when the runtime has none).

    Raises:
        RuntimeError: If provider resolution fails; the message comes from
            ``format_runtime_provider_error`` and the original exception is
            chained as the cause.
    """
    from hermes_cli.runtime_provider import (
        format_runtime_provider_error,
        resolve_runtime_provider,
    )

    try:
        resolved = resolve_runtime_provider(requested=provider)
    except Exception as exc:
        raise RuntimeError(format_runtime_provider_error(exc)) from exc

    # Straight pass-through fields first, in the order callers have always seen.
    kwargs = {
        key: resolved.get(key)
        for key in ("api_key", "base_url", "provider", "api_mode", "command")
    }
    # Copy args so callers can mutate the list without touching the runtime dict.
    kwargs["args"] = list(resolved.get("args") or [])
    kwargs["credential_pool"] = resolved.get("credential_pool")
    return kwargs
def _try_resolve_fallback_provider() -> dict | None:
"""Attempt to resolve credentials from the fallback_model/fallback_providers config."""
from hermes_cli.runtime_provider import resolve_runtime_provider
@ -2284,6 +2306,20 @@ def _resolve_gateway_model(config: dict | None = None) -> str:
return ""
def _get_channel_override(
    config: GatewayConfig,
    platform: Platform,
    chat_id: str,
) -> Optional[ChannelOverride]:
    """Look up the channel override configured for ``platform`` / ``chat_id``.

    Returns ``None`` when ``chat_id`` is falsy, when the platform has no
    config entry, when that entry has no channel overrides, or when no
    override is registered under ``str(chat_id)``.
    """
    if chat_id:
        entry = config.platforms.get(platform)
        overrides = entry.channel_overrides if entry else None
        if overrides:
            # Override keys are stored as strings, so normalize the lookup key too.
            return overrides.get(str(chat_id))
    return None
def _resolve_hermes_bin() -> Optional[list[str]]:
"""Resolve the Hermes update command as argv parts.
@ -3601,6 +3637,25 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
resolved_session_key, model, runtime_kwargs
)
cfg = getattr(self, "config", None)
if cfg and source is not None and source.chat_id:
ch = _get_channel_override(cfg, source.platform, str(source.chat_id))
if ch:
channel_touch = False
if ch.model and not (override and override.get("model")):
model = ch.model
channel_touch = True
if ch.provider and not (override and override.get("provider")):
runtime_kwargs = _resolve_runtime_agent_kwargs_for_provider(ch.provider)
runtime_model = runtime_kwargs.pop("model", None)
if runtime_model:
model = runtime_model or model
channel_touch = True
if channel_touch and override and resolved_session_key:
model, runtime_kwargs = self._apply_session_model_override(
resolved_session_key, model, runtime_kwargs
)
# When the config has no model.default but a provider was resolved
# (e.g. user ran `hermes auth add openai-codex` without `hermes model`),
# fall back to the provider's first catalog model so the API call
@ -4473,6 +4528,24 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
cfg = _load_gateway_runtime_config()
return str(cfg_get(cfg, "agent", "system_prompt", default="") or "").strip()
def _resolve_model_for_channel(self, platform: Platform, chat_id: str) -> str:
    """Pick the model for a channel.

    Returns the model from the matching channel override when the runner
    has a config and that override names a model; otherwise returns
    whatever ``_resolve_gateway_model()`` yields.
    """
    # getattr keeps this usable on a runner that was never given a config.
    gateway_config = getattr(self, "config", None)
    channel = (
        _get_channel_override(gateway_config, platform, chat_id)
        if gateway_config
        else None
    )
    if channel and channel.model:
        return channel.model
    return _resolve_gateway_model()
def _get_system_prompt_for_channel(self, platform: Platform, chat_id: str) -> str:
    """Return the system prompt to use for a channel.

    The channel override's ``system_prompt`` (stripped of surrounding
    whitespace) wins when it is non-blank. Otherwise the runner's
    ``_ephemeral_system_prompt`` is returned, or ``""`` when that attribute
    is missing or falsy.

    Args:
        platform: Platform the message arrived on.
        chat_id: Channel/chat identifier used to look up the override.

    Returns:
        The prompt text; never ``None``.
    """
    # getattr keeps this usable on a runner that was never given a config.
    config = getattr(self, "config", None)
    if config:
        override = _get_channel_override(config, platform, chat_id)
        # Strip *before* the truthiness test: a whitespace-only override must
        # fall back to the global prompt exactly like an empty one does,
        # instead of returning "" and silently dropping the global prompt.
        channel_prompt = (override.system_prompt or "").strip() if override else ""
        if channel_prompt:
            return channel_prompt
    return getattr(self, "_ephemeral_system_prompt", None) or ""
@staticmethod
def _load_reasoning_config() -> dict | None:
"""Load reasoning effort from config.yaml.
@ -16791,14 +16864,18 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
# Platform.LOCAL ("local") maps to "cli"; others pass through as-is.
platform_key = "cli" if source.platform == Platform.LOCAL else source.platform.value
# Combine platform context, per-channel context, and the user-configured
# ephemeral system prompt.
# Combine platform context, the YAML channel_prompts hint for this chat,
# and the prompt returned by _get_system_prompt_for_channel (the
# channel_overrides system_prompt if set, otherwise the global ephemeral).
combined_ephemeral = context_prompt or ""
event_channel_prompt = (channel_prompt or "").strip()
if event_channel_prompt:
combined_ephemeral = (combined_ephemeral + "\n\n" + event_channel_prompt).strip()
if self._ephemeral_system_prompt:
combined_ephemeral = (combined_ephemeral + "\n\n" + self._ephemeral_system_prompt).strip()
cfg_channel_prompt = self._get_system_prompt_for_channel(
source.platform, source.chat_id or ""
)
if cfg_channel_prompt:
combined_ephemeral = (combined_ephemeral + "\n\n" + cfg_channel_prompt).strip()
max_iterations = _current_max_iterations()

View file

@ -0,0 +1,128 @@
"""Tests for per-channel model and system prompt overrides (Fixes #1955)."""
import pytest
from gateway.config import (
ChannelOverride,
GatewayConfig,
Platform,
PlatformConfig,
)
from gateway.run import _get_channel_override, GatewayRunner
class TestGetChannelOverride:
    """Lookup behaviour of ``_get_channel_override`` against a ``GatewayConfig``."""

    @staticmethod
    def _discord_config(overrides):
        """Return a GatewayConfig whose enabled Discord platform carries *overrides*."""
        return GatewayConfig(
            platforms={
                Platform.DISCORD: PlatformConfig(
                    enabled=True,
                    channel_overrides=overrides,
                ),
            },
        )

    def test_no_override_when_empty_config(self):
        config = GatewayConfig()
        assert _get_channel_override(config, Platform.DISCORD, "123") is None

    def test_no_override_when_platform_not_configured(self):
        config = GatewayConfig(platforms={})
        assert _get_channel_override(config, Platform.DISCORD, "123") is None

    def test_no_override_when_channel_not_in_overrides(self):
        config = self._discord_config(
            {"999": ChannelOverride(model="openrouter/healer-alpha")}
        )
        assert _get_channel_override(config, Platform.DISCORD, "123") is None

    def test_returns_override_when_channel_matches(self):
        ov = ChannelOverride(
            model="openrouter/healer-alpha",
            provider="openrouter",
            system_prompt="You are a summarizer.",
        )
        config = self._discord_config({"1234567890": ov})
        result = _get_channel_override(config, Platform.DISCORD, "1234567890")
        assert result is not None
        assert result.model == "openrouter/healer-alpha"
        assert result.provider == "openrouter"
        assert result.system_prompt == "You are a summarizer."

    def test_returns_override_when_chat_id_is_int_like(self):
        """A non-str chat_id still matches the str override keys."""
        config = self._discord_config({"123": ChannelOverride(model="gpt-4")})
        # Pass a real int: the lookup applies str() to chat_id before the dict
        # access, and a str argument would never exercise that normalization.
        result = _get_channel_override(config, Platform.DISCORD, 123)
        assert result is not None
        assert result.model == "gpt-4"
class TestResolveModelForChannel:
    """``GatewayRunner._resolve_model_for_channel``: override first, global default otherwise."""

    def test_uses_channel_override_when_present(self):
        overrides = {"chan_1": ChannelOverride(model="anthropic/claude-opus-4.6")}
        discord = PlatformConfig(enabled=True, channel_overrides=overrides)
        # object.__new__ skips GatewayRunner.__init__; only .config is set here.
        runner = object.__new__(GatewayRunner)
        runner.config = GatewayConfig(platforms={Platform.DISCORD: discord})
        resolved = runner._resolve_model_for_channel(Platform.DISCORD, "chan_1")
        assert resolved == "anthropic/claude-opus-4.6"

    def test_falls_back_to_global_when_no_override(self, monkeypatch):
        # Pin the global default so the assertion does not depend on local config.
        monkeypatch.setattr(
            "gateway.run._resolve_gateway_model",
            lambda: "global-model/default",
        )
        discord = PlatformConfig(enabled=True, channel_overrides={})
        runner = object.__new__(GatewayRunner)
        runner.config = GatewayConfig(platforms={Platform.DISCORD: discord})
        resolved = runner._resolve_model_for_channel(Platform.DISCORD, "unknown_channel")
        assert resolved == "global-model/default"
class TestGetSystemPromptForChannel:
    """``GatewayRunner._get_system_prompt_for_channel``: override first, global prompt otherwise."""

    def test_uses_channel_override_when_present(self):
        overrides = {
            "chan_1": ChannelOverride(system_prompt="You are a coding assistant."),
        }
        discord = PlatformConfig(enabled=True, channel_overrides=overrides)
        # object.__new__ skips GatewayRunner.__init__; set only what the method reads.
        runner = object.__new__(GatewayRunner)
        runner.config = GatewayConfig(platforms={Platform.DISCORD: discord})
        runner._ephemeral_system_prompt = "Global prompt"
        result = runner._get_system_prompt_for_channel(Platform.DISCORD, "chan_1")
        assert result == "You are a coding assistant."

    def test_falls_back_to_global_when_no_override(self):
        discord = PlatformConfig(enabled=True)
        runner = object.__new__(GatewayRunner)
        runner.config = GatewayConfig(platforms={Platform.DISCORD: discord})
        runner._ephemeral_system_prompt = "Global prompt"
        result = runner._get_system_prompt_for_channel(Platform.DISCORD, "other")
        assert result == "Global prompt"

View file

@ -5,6 +5,7 @@ import os
from unittest.mock import patch
from gateway.config import (
ChannelOverride,
GatewayConfig,
HomeChannel,
Platform,
@ -89,6 +90,54 @@ class TestPlatformConfigRoundtrip:
# extra; from_dict must honor it there too (mirrors _grn fallback).
restored = PlatformConfig.from_dict({"extra": {"typing_indicator": False}})
assert restored.typing_indicator is False
def test_channel_overrides_roundtrip(self):
    """channel_overrides survive a to_dict() / from_dict() round trip."""
    news_id, dev_id = "1234567890", "9876543210"
    original = PlatformConfig(
        enabled=True,
        channel_overrides={
            news_id: ChannelOverride(
                model="openrouter/healer-alpha",
                provider="openrouter",
                system_prompt="You are a daily news summarizer.",
            ),
            dev_id: ChannelOverride(
                model="anthropic/claude-opus-4.6",
                provider="anthropic",
                system_prompt="You are a coding assistant.",
            ),
        },
    )

    serialized = original.to_dict()
    assert "channel_overrides" in serialized
    dumped = serialized["channel_overrides"]
    assert dumped[news_id]["model"] == "openrouter/healer-alpha"
    assert dumped[dev_id]["system_prompt"] == "You are a coding assistant."

    reloaded = PlatformConfig.from_dict(serialized).channel_overrides
    assert reloaded[news_id].model == "openrouter/healer-alpha"
    assert reloaded[dev_id].provider == "anthropic"
def test_channel_overrides_from_dict_normalizes_channel_id_to_str(self):
    """YAML may have numeric channel IDs; we store as str."""
    numeric_id = 1234567890
    raw = {
        "enabled": True,
        "channel_overrides": {numeric_id: {"model": "openrouter/healer-alpha"}},
    }
    overrides = PlatformConfig.from_dict(raw).channel_overrides
    # The int key must come back as its string form.
    assert "1234567890" in overrides
    assert overrides["1234567890"].model == "openrouter/healer-alpha"
class TestChannelOverride:
    """Serialization helpers on ``ChannelOverride``."""

    def test_from_dict_empty(self):
        # Both an empty dict and None must produce an override with no model.
        for empty in ({}, None):
            assert ChannelOverride.from_dict(empty).model is None

    def test_to_dict_omits_none(self):
        override = ChannelOverride(model="gpt-4", provider=None, system_prompt="Hi")
        dumped = override.to_dict()
        assert dumped["model"] == "gpt-4"
        assert "provider" not in dumped
        assert dumped["system_prompt"] == "Hi"
class TestGetConnectedPlatforms: