feat(gateway): per-channel model and system prompt overrides (Fixes #1955)
- ChannelOverride + channel_overrides on PlatformConfig
- Resolve model/runtime: session /model, then channel_overrides, then global
- Thread/parent channel lookup; bridge discord.channel_overrides from YAML
- Drop unrelated test and delegate_tool changes from PR scope
This commit is contained in:
parent
ebef73f6b8
commit
0010c14e66
4 changed files with 336 additions and 40 deletions
|
|
@ -379,7 +379,7 @@ class PlatformConfig:
|
|||
# noise; keep True for back-channels where the operator wants them.
|
||||
gateway_restart_notification: bool = True
|
||||
|
||||
# Whether the gateway shows a "typing…" / "is thinking…" status indicator
|
||||
# Whether the gateway shows a "typing…" / "is thinking…" status indicator
|
||||
# while the agent processes a message on this platform. Default True
|
||||
# preserves prior behavior. Set False on platforms where the indicator is
|
||||
# unwanted (e.g. Slack's assistant.threads.setStatus "is thinking…", which
|
||||
|
|
@ -420,7 +420,7 @@ class PlatformConfig:
|
|||
if "home_channel" in data:
|
||||
home_channel = HomeChannel.from_dict(data["home_channel"])
|
||||
|
||||
# gateway_restart_notification may be bridged into extra via the
|
||||
# gateway_restart_notification may be bridged into extra via the
|
||||
# shared-key loop in load_gateway_config(); check both top-level
|
||||
# and extra so YAML ``discord: gateway_restart_notification: false``
|
||||
# works without needing a separate platforms: block.
|
||||
|
|
@ -448,7 +448,7 @@ class PlatformConfig:
|
|||
api_key=data.get("api_key"),
|
||||
home_channel=home_channel,
|
||||
reply_to_mode=data.get("reply_to_mode", "first"),
|
||||
gateway_restart_notification=_coerce_bool(_grn, True),
|
||||
gateway_restart_notification=_coerce_bool(_grn, True),
|
||||
typing_indicator=_coerce_bool(_typing, True),
|
||||
channel_overrides=channel_overrides,
|
||||
extra=data.get("extra", {}),
|
||||
|
|
@ -1103,8 +1103,20 @@ def load_gateway_config() -> GatewayConfig:
|
|||
bridged["gateway_restart_notification"] = platform_cfg["gateway_restart_notification"]
|
||||
if "typing_indicator" in platform_cfg:
|
||||
bridged["typing_indicator"] = platform_cfg["typing_indicator"]
|
||||
has_channel_overrides = "channel_overrides" in platform_cfg
|
||||
if has_channel_overrides:
|
||||
raw_overrides = platform_cfg.get("channel_overrides")
|
||||
if isinstance(raw_overrides, dict):
|
||||
plat_data, _extra = _ensure_platform_extra_dict(
|
||||
platforms_data, plat.value
|
||||
)
|
||||
plat_data["channel_overrides"] = {
|
||||
str(cid): ov_data
|
||||
for cid, ov_data in raw_overrides.items()
|
||||
if isinstance(ov_data, dict)
|
||||
}
|
||||
enabled_was_explicit = _cfg_toplevel and "enabled" in platform_cfg
|
||||
if not bridged and not enabled_was_explicit:
|
||||
if not bridged and not enabled_was_explicit and not has_channel_overrides:
|
||||
continue
|
||||
plat_data, extra = _ensure_platform_extra_dict(platforms_data, plat.value)
|
||||
if enabled_was_explicit:
|
||||
|
|
|
|||
155
gateway/run.py
155
gateway/run.py
|
|
@ -2306,18 +2306,54 @@ def _resolve_gateway_model(config: dict | None = None) -> str:
|
|||
return ""
|
||||
|
||||
|
||||
def _channel_override_lookup_keys(
|
||||
chat_id: str,
|
||||
*,
|
||||
thread_id: Optional[str] = None,
|
||||
parent_id: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""Ordered, de-duplicated keys for ``channel_overrides`` lookup.
|
||||
|
||||
Matches ``resolve_channel_prompt`` semantics: exact thread/channel id first,
|
||||
then parent channel/forum id (Discord threads inherit parent overrides).
|
||||
"""
|
||||
keys: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for key in (chat_id, thread_id, parent_id):
|
||||
if not key:
|
||||
continue
|
||||
sk = str(key)
|
||||
if sk in seen:
|
||||
continue
|
||||
seen.add(sk)
|
||||
keys.append(sk)
|
||||
return keys
|
||||
|
||||
|
||||
def _get_channel_override(
|
||||
config: GatewayConfig,
|
||||
platform: Platform,
|
||||
chat_id: str,
|
||||
*,
|
||||
thread_id: Optional[str] = None,
|
||||
parent_id: Optional[str] = None,
|
||||
) -> Optional[ChannelOverride]:
|
||||
"""Return per-channel override for this platform/chat_id, or None."""
|
||||
if not chat_id:
|
||||
return None
|
||||
"""Return per-channel override for this platform/chat_id, or None.
|
||||
|
||||
Looks up ``channel_overrides`` by ``chat_id``, then ``thread_id``, then
|
||||
``parent_id`` (forum threads / child channels inherit the parent entry).
|
||||
"""
|
||||
platform_config = config.platforms.get(platform)
|
||||
if not platform_config or not platform_config.channel_overrides:
|
||||
return None
|
||||
return platform_config.channel_overrides.get(str(chat_id))
|
||||
overrides = platform_config.channel_overrides
|
||||
for key in _channel_override_lookup_keys(
|
||||
chat_id, thread_id=thread_id, parent_id=parent_id
|
||||
):
|
||||
ov = overrides.get(key)
|
||||
if ov is not None:
|
||||
return ov
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_hermes_bin() -> Optional[list[str]]:
|
||||
|
|
@ -3579,11 +3615,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
session_key: Optional[str] = None,
|
||||
user_config: Optional[dict] = None,
|
||||
) -> tuple[str, dict]:
|
||||
"""Resolve model/runtime for a session, honoring session-scoped /model overrides.
|
||||
"""Resolve model/runtime for a session.
|
||||
|
||||
If the session override already contains a complete provider bundle
|
||||
(provider/api_key/base_url/api_mode), prefer it directly instead of
|
||||
resolving fresh global runtime state first.
|
||||
Priority (highest first): session ``/model`` → ``channel_overrides`` →
|
||||
global config/env (``_resolve_gateway_model(user_config)`` and default
|
||||
provider resolution).
|
||||
"""
|
||||
resolved_session_key = session_key
|
||||
if not resolved_session_key and source is not None:
|
||||
|
|
@ -3632,30 +3668,43 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
runtime_model,
|
||||
)
|
||||
model = runtime_model
|
||||
|
||||
cfg = getattr(self, "config", None)
|
||||
if cfg and source is not None:
|
||||
chat_id = str(source.chat_id) if source.chat_id else ""
|
||||
thread_id = (
|
||||
str(source.thread_id) if getattr(source, "thread_id", None) else None
|
||||
)
|
||||
parent_id = (
|
||||
str(source.parent_chat_id)
|
||||
if getattr(source, "parent_chat_id", None)
|
||||
else None
|
||||
)
|
||||
ch = _get_channel_override(
|
||||
cfg,
|
||||
source.platform,
|
||||
chat_id,
|
||||
thread_id=thread_id,
|
||||
parent_id=parent_id,
|
||||
)
|
||||
if ch:
|
||||
if ch.model:
|
||||
model = ch.model
|
||||
if ch.provider:
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs_for_provider(
|
||||
ch.provider
|
||||
)
|
||||
ch_runtime_model = runtime_kwargs.pop("model", None)
|
||||
# Only adopt the provider's bundled model when the override
|
||||
# did not specify an explicit model.
|
||||
if ch_runtime_model and not ch.model:
|
||||
model = ch_runtime_model
|
||||
|
||||
if override and resolved_session_key:
|
||||
model, runtime_kwargs = self._apply_session_model_override(
|
||||
resolved_session_key, model, runtime_kwargs
|
||||
)
|
||||
|
||||
cfg = getattr(self, "config", None)
|
||||
if cfg and source is not None and source.chat_id:
|
||||
ch = _get_channel_override(cfg, source.platform, str(source.chat_id))
|
||||
if ch:
|
||||
channel_touch = False
|
||||
if ch.model and not (override and override.get("model")):
|
||||
model = ch.model
|
||||
channel_touch = True
|
||||
if ch.provider and not (override and override.get("provider")):
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs_for_provider(ch.provider)
|
||||
runtime_model = runtime_kwargs.pop("model", None)
|
||||
if runtime_model:
|
||||
model = runtime_model or model
|
||||
channel_touch = True
|
||||
if channel_touch and override and resolved_session_key:
|
||||
model, runtime_kwargs = self._apply_session_model_override(
|
||||
resolved_session_key, model, runtime_kwargs
|
||||
)
|
||||
|
||||
# When the config has no model.default but a provider was resolved
|
||||
# (e.g. user ran `hermes auth add openai-codex` without `hermes model`),
|
||||
# fall back to the provider's first catalog model so the API call
|
||||
|
|
@ -4528,20 +4577,53 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
cfg = _load_gateway_runtime_config()
|
||||
return str(cfg_get(cfg, "agent", "system_prompt", default="") or "").strip()
|
||||
|
||||
def _resolve_model_for_channel(self, platform: Platform, chat_id: str) -> str:
|
||||
"""Resolve model for this channel: channel_overrides[channel_id] else global default."""
|
||||
def _resolve_model_for_channel(
|
||||
self,
|
||||
platform: Platform,
|
||||
chat_id: str,
|
||||
*,
|
||||
user_config: Optional[dict] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
parent_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Resolve model for this channel: channel_overrides else global default."""
|
||||
config = getattr(self, "config", None)
|
||||
if config:
|
||||
override = _get_channel_override(config, platform, chat_id)
|
||||
override = _get_channel_override(
|
||||
config,
|
||||
platform,
|
||||
chat_id,
|
||||
thread_id=thread_id,
|
||||
parent_id=parent_id,
|
||||
)
|
||||
if override and override.model:
|
||||
return override.model
|
||||
return _resolve_gateway_model()
|
||||
return _resolve_gateway_model(user_config)
|
||||
|
||||
def _get_system_prompt_for_channel(self, platform: Platform, chat_id: str) -> str:
|
||||
"""System prompt for this channel: channel override else global ephemeral."""
|
||||
def _get_system_prompt_for_channel(
|
||||
self,
|
||||
platform: Platform,
|
||||
chat_id: str,
|
||||
*,
|
||||
thread_id: Optional[str] = None,
|
||||
parent_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Ephemeral system prompt for this channel/thread.
|
||||
|
||||
Uses ``channel_overrides`` when set, else the global gateway prompt.
|
||||
Legacy ``channel_prompts`` are applied separately via ``event.channel_prompt``
|
||||
in ``run_sync`` (adapter ``resolve_channel_prompt``), so they are not
|
||||
duplicated here.
|
||||
"""
|
||||
config = getattr(self, "config", None)
|
||||
if config:
|
||||
override = _get_channel_override(config, platform, chat_id)
|
||||
override = _get_channel_override(
|
||||
config,
|
||||
platform,
|
||||
chat_id,
|
||||
thread_id=thread_id,
|
||||
parent_id=parent_id,
|
||||
)
|
||||
if override and override.system_prompt:
|
||||
return (override.system_prompt or "").strip()
|
||||
return getattr(self, "_ephemeral_system_prompt", None) or ""
|
||||
|
|
@ -16872,7 +16954,10 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
if event_channel_prompt:
|
||||
combined_ephemeral = (combined_ephemeral + "\n\n" + event_channel_prompt).strip()
|
||||
cfg_channel_prompt = self._get_system_prompt_for_channel(
|
||||
source.platform, source.chat_id or ""
|
||||
source.platform,
|
||||
source.chat_id or "",
|
||||
thread_id=getattr(source, "thread_id", None),
|
||||
parent_id=getattr(source, "parent_chat_id", None),
|
||||
)
|
||||
if cfg_channel_prompt:
|
||||
combined_ephemeral = (combined_ephemeral + "\n\n" + cfg_channel_prompt).strip()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""Tests for per-channel model and system prompt overrides (Fixes #1955)."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import (
|
||||
|
|
@ -9,6 +11,7 @@ from gateway.config import (
|
|||
PlatformConfig,
|
||||
)
|
||||
from gateway.run import _get_channel_override, GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class TestGetChannelOverride:
|
||||
|
|
@ -65,6 +68,60 @@ class TestGetChannelOverride:
|
|||
)
|
||||
assert _get_channel_override(config, Platform.DISCORD, "123").model == "gpt-4"
|
||||
|
||||
def test_thread_id_lookup_when_chat_id_misses(self):
    """An entry keyed by ``thread_id`` is returned when ``chat_id`` has no entry."""
    discord_cfg = PlatformConfig(
        enabled=True,
        channel_overrides={"thread_99": ChannelOverride(model="topic-model")},
    )
    config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})

    found = _get_channel_override(
        config,
        Platform.DISCORD,
        "parent_chan",
        thread_id="thread_99",
    )

    assert found is not None
    assert found.model == "topic-model"
|
||||
|
||||
def test_parent_id_fallback_when_thread_has_no_entry(self):
    """With no entry for the chat itself, the ``parent_id`` entry is returned."""
    discord_cfg = PlatformConfig(
        enabled=True,
        channel_overrides={"parent_chan": ChannelOverride(model="parent-model")},
    )
    config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})

    found = _get_channel_override(
        config, Platform.DISCORD, "thread_only", parent_id="parent_chan"
    )

    assert found is not None
    assert found.model == "parent-model"
|
||||
|
||||
def test_exact_thread_overrides_parent(self):
    """An entry for the chat id itself wins over the parent channel's entry."""
    overrides = {
        "thread_1": ChannelOverride(model="thread-model"),
        "parent_chan": ChannelOverride(model="parent-model"),
    }
    discord_cfg = PlatformConfig(enabled=True, channel_overrides=overrides)
    config = GatewayConfig(platforms={Platform.DISCORD: discord_cfg})

    found = _get_channel_override(
        config,
        Platform.DISCORD,
        "thread_1",
        parent_id="parent_chan",
    )

    assert found.model == "thread-model"
|
||||
|
||||
|
||||
class TestResolveModelForChannel:
|
||||
def test_uses_channel_override_when_present(self):
|
||||
|
|
@ -86,7 +143,7 @@ class TestResolveModelForChannel:
|
|||
def test_falls_back_to_global_when_no_override(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"gateway.run._resolve_gateway_model",
|
||||
lambda: "global-model/default",
|
||||
lambda _cfg=None: "global-model/default",
|
||||
)
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
|
|
@ -126,3 +183,120 @@ class TestGetSystemPromptForChannel:
|
|||
runner._ephemeral_system_prompt = "Global prompt"
|
||||
prompt = runner._get_system_prompt_for_channel(Platform.DISCORD, "other")
|
||||
assert prompt == "Global prompt"
|
||||
|
||||
|
||||
class TestResolveSessionAgentRuntimePriority:
    """Model/runtime priority: session /model → channel_overrides → global."""

    @staticmethod
    def _runner_with(channel_overrides, session_overrides=None):
        """Create a GatewayRunner without running ``__init__``.

        Only the two attributes these tests assign are populated:
        ``_session_model_overrides`` and a ``config`` holding a single enabled
        Discord platform with the given ``channel_overrides``.
        """
        runner = object.__new__(GatewayRunner)
        runner._session_model_overrides = (
            {} if session_overrides is None else session_overrides
        )
        runner.config = GatewayConfig(
            platforms={
                Platform.DISCORD: PlatformConfig(
                    enabled=True,
                    channel_overrides=channel_overrides,
                ),
            },
        )
        return runner

    @staticmethod
    def _runtime(provider, api_key, base_url):
        """Return a fresh runtime-kwargs dict used as a mocked resolver result."""
        return {
            "provider": provider,
            "api_key": api_key,
            "base_url": base_url,
            "api_mode": "chat_completions",
        }

    def test_channel_override_beats_global(self):
        """A channel override's model and provider replace the global ones."""
        runner = self._runner_with(
            {
                "chan_1": ChannelOverride(
                    model="channel/model",
                    provider="openrouter",
                ),
            }
        )
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="chan_1",
            user_id="u1",
        )
        global_runtime = self._runtime("anthropic", "k", "https://api.anthropic.com")
        channel_runtime = self._runtime(
            "openrouter", "k2", "https://openrouter.ai/api/v1"
        )

        with patch("gateway.run._resolve_gateway_model", return_value="global/model"), \
                patch("gateway.run._resolve_runtime_agent_kwargs", return_value=global_runtime), \
                patch(
                    "gateway.run._resolve_runtime_agent_kwargs_for_provider",
                    return_value=channel_runtime,
                ):
            model, runtime = runner._resolve_session_agent_runtime(
                source=source,
                user_config={"model": {"default": "global/model"}},
            )

        assert model == "channel/model"
        assert runtime["provider"] == "openrouter"

    def test_session_model_beats_channel_override(self):
        """A session-scoped override takes precedence over the channel override."""
        session_key = "agent:main:discord:channel:chan_1"
        runner = self._runner_with(
            {"chan_1": ChannelOverride(model="channel/model")},
            session_overrides={
                session_key: {
                    "model": "session/model",
                    "provider": "anthropic",
                },
            },
        )
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="chan_1",
            chat_type="channel",
            user_id="u1",
        )
        global_runtime = self._runtime(
            "openrouter", "k", "https://openrouter.ai/api/v1"
        )

        with patch("gateway.run._resolve_gateway_model", return_value="global/model"), \
                patch("gateway.run._resolve_runtime_agent_kwargs", return_value=global_runtime):
            model, runtime = runner._resolve_session_agent_runtime(
                source=source,
                session_key=session_key,
            )

        assert model == "session/model"
        assert runtime["provider"] == "anthropic"

    def test_parent_channel_model_inherited_in_thread(self):
        """A thread with no entry of its own picks up its parent channel's model."""
        runner = self._runner_with(
            {"parent_chan": ChannelOverride(model="parent/model")}
        )
        source = SessionSource(
            platform=Platform.DISCORD,
            chat_id="thread_1",
            chat_type="thread",
            parent_chat_id="parent_chan",
            user_id="u1",
        )
        global_runtime = self._runtime("anthropic", "k", "https://api.anthropic.com")

        with patch("gateway.run._resolve_gateway_model", return_value="global/model"), \
                patch("gateway.run._resolve_runtime_agent_kwargs", return_value=global_runtime):
            model, _runtime = runner._resolve_session_agent_runtime(source=source)

        assert model == "parent/model"
|
||||
|
|
|
|||
|
|
@ -831,6 +831,31 @@ class TestLoadGatewayConfig:
|
|||
|
||||
assert config.always_log_local is False
|
||||
|
||||
def test_bridges_discord_channel_overrides_from_top_level_yaml(self, tmp_path, monkeypatch):
    """Top-level ``discord.channel_overrides`` YAML reaches the Discord PlatformConfig.

    Writes a ``config.yaml`` under a temporary ``HERMES_HOME``, loads the
    gateway config, and checks that the single override entry is exposed on
    ``config.platforms[Platform.DISCORD].channel_overrides`` with its model,
    provider and system prompt intact.
    """
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    config_path = hermes_home / "config.yaml"
    # YAML nesting is expressed purely through indentation: each level must be
    # indented deeper than its parent, otherwise ``channel_overrides`` parses
    # as an empty key and the channel id / fields become its siblings.
    config_path.write_text(
        "discord:\n"
        "  channel_overrides:\n"
        '    "1234567890":\n'
        "      model: openrouter/healer-alpha\n"
        "      provider: openrouter\n"
        "      system_prompt: Daily news summarizer\n",
        encoding="utf-8",
    )

    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    config = load_gateway_config()

    discord = config.platforms[Platform.DISCORD]
    assert "1234567890" in discord.channel_overrides
    ov = discord.channel_overrides["1234567890"]
    assert ov.model == "openrouter/healer-alpha"
    assert ov.provider == "openrouter"
    assert ov.system_prompt == "Daily news summarizer"
|
||||
|
||||
def test_bridges_discord_channel_prompts_from_config_yaml(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue