feat(gateway): per-channel model and system prompt overrides (Fixes #1955)
- config: ChannelOverride + PlatformConfig.channel_overrides
- run: _resolve_model_for_channel, _get_system_prompt_for_channel, channel provider runtime
- tests: channel overrides + config guard for bare runner; conftest asyncio fix; slack/whatsapp warning filters

Made-with: Cursor
This commit is contained in:
parent
902b0b70e4
commit
ebef73f6b8
4 changed files with 311 additions and 8 deletions
|
|
@ -324,6 +324,40 @@ class SessionResetPolicy:
|
|||
)
|
||||
|
||||
|
||||
@dataclass
class ChannelOverride:
    """
    Per-channel override for model, provider, and system prompt.

    Used in config under platforms.<name>.channel_overrides[channel_id].
    Enables different channels (e.g. Discord #daily vs #dev) to use different
    models and personas without running separate gateway instances.

    Every field is optional; ``None`` means "no override for this setting".
    """

    model: Optional[str] = None
    provider: Optional[str] = None
    system_prompt: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, leaving out every field that is ``None``."""
        candidates = (
            ("model", self.model),
            ("provider", self.provider),
            ("system_prompt", self.system_prompt),
        )
        return {key: value for key, value in candidates if value is not None}

    @classmethod
    def from_dict(cls, data: Optional[Dict[str, Any]]) -> "ChannelOverride":
        """Build an override from a config mapping.

        Args:
            data: Mapping that may contain ``model``, ``provider`` and
                ``system_prompt`` keys. ``None``, an empty mapping, or any
                non-mapping value yields an override with all fields unset.

        Returns:
            A new ``ChannelOverride``; keys missing from ``data`` stay ``None``.
        """
        from collections.abc import Mapping

        # A malformed config entry (e.g. a bare string instead of a mapping)
        # must not crash config loading with AttributeError on ``.get``.
        if not data or not isinstance(data, Mapping):
            return cls()
        return cls(
            model=data.get("model"),
            provider=data.get("provider"),
            system_prompt=data.get("system_prompt"),
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformConfig:
|
||||
"""Configuration for a single messaging platform."""
|
||||
|
|
@ -331,7 +365,7 @@ class PlatformConfig:
|
|||
token: Optional[str] = None # Bot token (Telegram, Discord)
|
||||
api_key: Optional[str] = None # API key if different from token
|
||||
home_channel: Optional[HomeChannel] = None
|
||||
|
||||
|
||||
# Reply threading mode (Telegram/Slack)
|
||||
# - "off": Never thread replies to original message
|
||||
# - "first": Only first chunk threads to user's message (default)
|
||||
|
|
@ -345,7 +379,7 @@ class PlatformConfig:
|
|||
# noise; keep True for back-channels where the operator wants them.
|
||||
gateway_restart_notification: bool = True
|
||||
|
||||
# Whether the gateway shows a "typing…" / "is thinking…" status indicator
|
||||
# Whether the gateway shows a "typing…" / "is thinking…" status indicator
|
||||
# while the agent processes a message on this platform. Default True
|
||||
# preserves prior behavior. Set False on platforms where the indicator is
|
||||
# unwanted (e.g. Slack's assistant.threads.setStatus "is thinking…", which
|
||||
|
|
@ -354,6 +388,9 @@ class PlatformConfig:
|
|||
# gateway/platforms/base.py.
|
||||
typing_indicator: bool = True
|
||||
|
||||
# Per-channel model/provider/system_prompt overrides (channel_id -> ChannelOverride)
|
||||
channel_overrides: Dict[str, ChannelOverride] = field(default_factory=dict)
|
||||
|
||||
# Platform-specific settings
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
|
@ -371,6 +408,10 @@ class PlatformConfig:
|
|||
result["api_key"] = self.api_key
|
||||
if self.home_channel:
|
||||
result["home_channel"] = self.home_channel.to_dict()
|
||||
if self.channel_overrides:
|
||||
result["channel_overrides"] = {
|
||||
cid: ov.to_dict() for cid, ov in self.channel_overrides.items()
|
||||
}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
|
|
@ -379,7 +420,7 @@ class PlatformConfig:
|
|||
if "home_channel" in data:
|
||||
home_channel = HomeChannel.from_dict(data["home_channel"])
|
||||
|
||||
# gateway_restart_notification may be bridged into extra via the
|
||||
# gateway_restart_notification may be bridged into extra via the
|
||||
# shared-key loop in load_gateway_config(); check both top-level
|
||||
# and extra so YAML ``discord: gateway_restart_notification: false``
|
||||
# works without needing a separate platforms: block.
|
||||
|
|
@ -394,14 +435,22 @@ class PlatformConfig:
|
|||
if _typing is None:
|
||||
_typing = data.get("extra", {}).get("typing_indicator")
|
||||
|
||||
channel_overrides: Dict[str, ChannelOverride] = {}
|
||||
raw_overrides = data.get("channel_overrides") or {}
|
||||
if isinstance(raw_overrides, dict):
|
||||
for cid, ov_data in raw_overrides.items():
|
||||
if isinstance(ov_data, dict):
|
||||
channel_overrides[str(cid)] = ChannelOverride.from_dict(ov_data)
|
||||
|
||||
return cls(
|
||||
enabled=_coerce_bool(data.get("enabled"), False),
|
||||
token=data.get("token"),
|
||||
api_key=data.get("api_key"),
|
||||
home_channel=home_channel,
|
||||
reply_to_mode=data.get("reply_to_mode", "first"),
|
||||
gateway_restart_notification=_coerce_bool(_grn, True),
|
||||
gateway_restart_notification=_coerce_bool(_grn, True),
|
||||
typing_indicator=_coerce_bool(_typing, True),
|
||||
channel_overrides=channel_overrides,
|
||||
extra=data.get("extra", {}),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1660,6 +1660,7 @@ if not _configured_cwd or _configured_cwd in {".", "auto", "cwd"}:
|
|||
os.environ["TERMINAL_CWD"] = _fallback
|
||||
|
||||
from gateway.config import (
|
||||
ChannelOverride,
|
||||
Platform,
|
||||
_BUILTIN_PLATFORM_VALUES,
|
||||
GatewayConfig,
|
||||
|
|
@ -1825,6 +1826,27 @@ def _resolve_runtime_agent_kwargs() -> dict:
|
|||
}
|
||||
|
||||
|
||||
def _resolve_runtime_agent_kwargs_for_provider(provider: str) -> dict:
    """Resolve runtime credentials for one explicitly requested provider.

    Used when a provider is named directly (e.g. by a channel override).

    Args:
        provider: Provider identifier passed to ``resolve_runtime_provider``
            as ``requested``.

    Returns:
        Dict with the keys ``api_key``, ``base_url``, ``provider``,
        ``api_mode``, ``command``, ``args`` (always a fresh list) and
        ``credential_pool``, copied from the resolved runtime.

    Raises:
        RuntimeError: If resolution fails; the message comes from
            ``format_runtime_provider_error`` and the original exception is
            chained as the cause.
    """
    from hermes_cli.runtime_provider import (
        format_runtime_provider_error,
        resolve_runtime_provider,
    )

    try:
        resolved = resolve_runtime_provider(requested=provider)
    except Exception as exc:
        raise RuntimeError(format_runtime_provider_error(exc)) from exc

    passthrough_keys = ("api_key", "base_url", "provider", "api_mode", "command")
    agent_kwargs = {key: resolved.get(key) for key in passthrough_keys}
    # Copy ``args`` so callers never share the resolver's list.
    agent_kwargs["args"] = list(resolved.get("args") or [])
    agent_kwargs["credential_pool"] = resolved.get("credential_pool")
    return agent_kwargs
|
||||
|
||||
|
||||
def _try_resolve_fallback_provider() -> dict | None:
|
||||
"""Attempt to resolve credentials from the fallback_model/fallback_providers config."""
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
|
|
@ -2284,6 +2306,20 @@ def _resolve_gateway_model(config: dict | None = None) -> str:
|
|||
return ""
|
||||
|
||||
|
||||
def _get_channel_override(
    config: GatewayConfig,
    platform: Platform,
    chat_id: str,
) -> Optional[ChannelOverride]:
    """Look up the per-channel override for ``platform`` / ``chat_id``.

    Returns ``None`` when ``chat_id`` is empty, when the platform has no
    configuration or no ``channel_overrides``, or when the channel has no
    entry. The lookup key is ``str(chat_id)``.
    """
    if chat_id:
        platform_cfg = config.platforms.get(platform)
        overrides = platform_cfg.channel_overrides if platform_cfg else None
        if overrides:
            return overrides.get(str(chat_id))
    return None
|
||||
|
||||
|
||||
def _resolve_hermes_bin() -> Optional[list[str]]:
|
||||
"""Resolve the Hermes update command as argv parts.
|
||||
|
||||
|
|
@ -3601,6 +3637,25 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
resolved_session_key, model, runtime_kwargs
|
||||
)
|
||||
|
||||
cfg = getattr(self, "config", None)
|
||||
if cfg and source is not None and source.chat_id:
|
||||
ch = _get_channel_override(cfg, source.platform, str(source.chat_id))
|
||||
if ch:
|
||||
channel_touch = False
|
||||
if ch.model and not (override and override.get("model")):
|
||||
model = ch.model
|
||||
channel_touch = True
|
||||
if ch.provider and not (override and override.get("provider")):
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs_for_provider(ch.provider)
|
||||
runtime_model = runtime_kwargs.pop("model", None)
|
||||
if runtime_model:
|
||||
model = runtime_model or model
|
||||
channel_touch = True
|
||||
if channel_touch and override and resolved_session_key:
|
||||
model, runtime_kwargs = self._apply_session_model_override(
|
||||
resolved_session_key, model, runtime_kwargs
|
||||
)
|
||||
|
||||
# When the config has no model.default but a provider was resolved
|
||||
# (e.g. user ran `hermes auth add openai-codex` without `hermes model`),
|
||||
# fall back to the provider's first catalog model so the API call
|
||||
|
|
@ -4473,6 +4528,24 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
cfg = _load_gateway_runtime_config()
|
||||
return str(cfg_get(cfg, "agent", "system_prompt", default="") or "").strip()
|
||||
|
||||
def _resolve_model_for_channel(self, platform: Platform, chat_id: str) -> str:
    """Pick the model for a channel.

    A model set in the ``channel_overrides`` entry for ``chat_id`` wins;
    otherwise the result of ``_resolve_gateway_model()`` is returned.
    ``self.config`` is read via ``getattr`` so a runner without a ``config``
    attribute still falls through to the global model.
    """
    gateway_config = getattr(self, "config", None)
    channel_cfg = (
        _get_channel_override(gateway_config, platform, chat_id)
        if gateway_config
        else None
    )
    if channel_cfg and channel_cfg.model:
        return channel_cfg.model
    return _resolve_gateway_model()
|
||||
|
||||
def _get_system_prompt_for_channel(self, platform: Platform, chat_id: str) -> str:
    """Return the system prompt to use for this channel.

    Args:
        platform: Platform the message arrived on.
        chat_id: Channel/chat identifier used to look up the override.

    Returns:
        The channel override's ``system_prompt`` with surrounding whitespace
        removed when it contains any non-whitespace text; otherwise the
        runner's ``_ephemeral_system_prompt`` (or ``""`` when that is unset).
    """
    config = getattr(self, "config", None)
    if config:
        override = _get_channel_override(config, platform, chat_id)
        raw_prompt = override.system_prompt if override else None
        # Strip *before* testing: a whitespace-only override must behave like
        # an empty one and fall back to the global prompt instead of
        # replacing it with "".
        channel_prompt = (raw_prompt or "").strip()
        if channel_prompt:
            return channel_prompt
    return getattr(self, "_ephemeral_system_prompt", None) or ""
|
||||
|
||||
@staticmethod
|
||||
def _load_reasoning_config() -> dict | None:
|
||||
"""Load reasoning effort from config.yaml.
|
||||
|
|
@ -16791,14 +16864,18 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
# Platform.LOCAL ("local") maps to "cli"; others pass through as-is.
|
||||
platform_key = "cli" if source.platform == Platform.LOCAL else source.platform.value
|
||||
|
||||
# Combine platform context, per-channel context, and the user-configured
|
||||
# ephemeral system prompt.
|
||||
# Combine platform context, YAML channel_prompts hint for this chat,
|
||||
# channel_overrides system_prompt (or global ephemeral), and gateway
|
||||
# ephemeral prompt from _get_system_prompt_for_channel.
|
||||
combined_ephemeral = context_prompt or ""
|
||||
event_channel_prompt = (channel_prompt or "").strip()
|
||||
if event_channel_prompt:
|
||||
combined_ephemeral = (combined_ephemeral + "\n\n" + event_channel_prompt).strip()
|
||||
if self._ephemeral_system_prompt:
|
||||
combined_ephemeral = (combined_ephemeral + "\n\n" + self._ephemeral_system_prompt).strip()
|
||||
cfg_channel_prompt = self._get_system_prompt_for_channel(
|
||||
source.platform, source.chat_id or ""
|
||||
)
|
||||
if cfg_channel_prompt:
|
||||
combined_ephemeral = (combined_ephemeral + "\n\n" + cfg_channel_prompt).strip()
|
||||
|
||||
max_iterations = _current_max_iterations()
|
||||
|
||||
|
|
|
|||
128
tests/gateway/test_channel_overrides.py
Normal file
128
tests/gateway/test_channel_overrides.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
"""Tests for per-channel model and system prompt overrides (Fixes #1955)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import (
|
||||
ChannelOverride,
|
||||
GatewayConfig,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
)
|
||||
from gateway.run import _get_channel_override, GatewayRunner
|
||||
|
||||
|
||||
class TestGetChannelOverride:
    """Lookup behaviour of ``_get_channel_override``."""

    @staticmethod
    def _discord_config(overrides):
        """Gateway config with one enabled Discord platform carrying ``overrides``."""
        discord = PlatformConfig(enabled=True, channel_overrides=overrides)
        return GatewayConfig(platforms={Platform.DISCORD: discord})

    def test_no_override_when_empty_config(self):
        found = _get_channel_override(GatewayConfig(), Platform.DISCORD, "123")
        assert found is None

    def test_no_override_when_platform_not_configured(self):
        found = _get_channel_override(
            GatewayConfig(platforms={}), Platform.DISCORD, "123"
        )
        assert found is None

    def test_no_override_when_channel_not_in_overrides(self):
        config = self._discord_config(
            {"999": ChannelOverride(model="openrouter/healer-alpha")}
        )
        assert _get_channel_override(config, Platform.DISCORD, "123") is None

    def test_returns_override_when_channel_matches(self):
        expected = ChannelOverride(
            model="openrouter/healer-alpha",
            provider="openrouter",
            system_prompt="You are a summarizer.",
        )
        config = self._discord_config({"1234567890": expected})

        found = _get_channel_override(config, Platform.DISCORD, "1234567890")

        assert found is not None
        assert found.model == "openrouter/healer-alpha"
        assert found.provider == "openrouter"
        assert found.system_prompt == "You are a summarizer."

    def test_returns_override_when_chat_id_is_int_like(self):
        """Caller may pass str(chat_id); override keys are normalized to str."""
        config = self._discord_config({"123": ChannelOverride(model="gpt-4")})
        found = _get_channel_override(config, Platform.DISCORD, "123")
        assert found.model == "gpt-4"
|
||||
|
||||
|
||||
class TestResolveModelForChannel:
    """``GatewayRunner._resolve_model_for_channel``: override first, global second."""

    @staticmethod
    def _bare_runner(overrides):
        """Runner created without ``__init__``, with only ``config`` populated."""
        discord = PlatformConfig(enabled=True, channel_overrides=overrides)
        runner = object.__new__(GatewayRunner)
        runner.config = GatewayConfig(platforms={Platform.DISCORD: discord})
        return runner

    def test_uses_channel_override_when_present(self):
        runner = self._bare_runner(
            {"chan_1": ChannelOverride(model="anthropic/claude-opus-4.6")}
        )
        resolved = runner._resolve_model_for_channel(Platform.DISCORD, "chan_1")
        assert resolved == "anthropic/claude-opus-4.6"

    def test_falls_back_to_global_when_no_override(self, monkeypatch):
        # Pin the global resolver so the fallback value is deterministic.
        monkeypatch.setattr(
            "gateway.run._resolve_gateway_model",
            lambda: "global-model/default",
        )
        runner = self._bare_runner({})
        resolved = runner._resolve_model_for_channel(
            Platform.DISCORD, "unknown_channel"
        )
        assert resolved == "global-model/default"
|
||||
|
||||
|
||||
class TestGetSystemPromptForChannel:
    """``GatewayRunner._get_system_prompt_for_channel``: override vs. global prompt."""

    @staticmethod
    def _bare_runner(discord_config):
        """Runner created without ``__init__``, with config and a global prompt set."""
        runner = object.__new__(GatewayRunner)
        runner.config = GatewayConfig(platforms={Platform.DISCORD: discord_config})
        runner._ephemeral_system_prompt = "Global prompt"
        return runner

    def test_uses_channel_override_when_present(self):
        discord = PlatformConfig(
            enabled=True,
            channel_overrides={
                "chan_1": ChannelOverride(system_prompt="You are a coding assistant."),
            },
        )
        runner = self._bare_runner(discord)
        resolved = runner._get_system_prompt_for_channel(Platform.DISCORD, "chan_1")
        assert resolved == "You are a coding assistant."

    def test_falls_back_to_global_when_no_override(self):
        runner = self._bare_runner(PlatformConfig(enabled=True))
        resolved = runner._get_system_prompt_for_channel(Platform.DISCORD, "other")
        assert resolved == "Global prompt"
|
||||
|
|
@ -5,6 +5,7 @@ import os
|
|||
from unittest.mock import patch
|
||||
|
||||
from gateway.config import (
|
||||
ChannelOverride,
|
||||
GatewayConfig,
|
||||
HomeChannel,
|
||||
Platform,
|
||||
|
|
@ -89,6 +90,54 @@ class TestPlatformConfigRoundtrip:
|
|||
# extra; from_dict must honor it there too (mirrors _grn fallback).
|
||||
restored = PlatformConfig.from_dict({"extra": {"typing_indicator": False}})
|
||||
assert restored.typing_indicator is False
|
||||
def test_channel_overrides_roundtrip(self):
    """channel_overrides survive ``to_dict`` followed by ``from_dict``."""
    news_channel = ChannelOverride(
        model="openrouter/healer-alpha",
        provider="openrouter",
        system_prompt="You are a daily news summarizer.",
    )
    dev_channel = ChannelOverride(
        model="anthropic/claude-opus-4.6",
        provider="anthropic",
        system_prompt="You are a coding assistant.",
    )
    original = PlatformConfig(
        enabled=True,
        channel_overrides={
            "1234567890": news_channel,
            "9876543210": dev_channel,
        },
    )

    serialized = original.to_dict()
    assert "channel_overrides" in serialized
    dumped = serialized["channel_overrides"]
    assert dumped["1234567890"]["model"] == "openrouter/healer-alpha"
    assert dumped["9876543210"]["system_prompt"] == "You are a coding assistant."

    restored = PlatformConfig.from_dict(serialized)
    assert restored.channel_overrides["1234567890"].model == "openrouter/healer-alpha"
    assert restored.channel_overrides["9876543210"].provider == "anthropic"
|
||||
|
||||
def test_channel_overrides_from_dict_normalizes_channel_id_to_str(self):
    """YAML may have numeric channel IDs; we store as str."""
    numeric_channel_id = 1234567890
    loaded = PlatformConfig.from_dict(
        {
            "enabled": True,
            "channel_overrides": {
                numeric_channel_id: {"model": "openrouter/healer-alpha"},
            },
        }
    )
    overrides = loaded.channel_overrides
    assert "1234567890" in overrides
    assert overrides["1234567890"].model == "openrouter/healer-alpha"
|
||||
|
||||
|
||||
class TestChannelOverride:
    """Serialization edge cases for ``ChannelOverride``."""

    def test_from_dict_empty(self):
        # Both an empty mapping and ``None`` must yield an unset override.
        from_empty = ChannelOverride.from_dict({})
        assert from_empty.model is None
        from_none = ChannelOverride.from_dict(None)
        assert from_none.model is None

    def test_to_dict_omits_none(self):
        override = ChannelOverride(model="gpt-4", provider=None, system_prompt="Hi")
        serialized = override.to_dict()
        assert serialized["model"] == "gpt-4"
        assert "provider" not in serialized
        assert serialized["system_prompt"] == "Hi"
|
||||
|
||||
|
||||
class TestGetConnectedPlatforms:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue