fix(discord): channel name matching and flush pending sends on shutdown

Two related fixes to the Discord gateway adapter:

1. Channel name matching (free-response, allowed, ignored, no-thread channels)
   Previously these config values only matched against numeric channel IDs.
   If a user configured free_response_channels: cypher (by name), the adapter
   would silently ignore it because it only intersected against channel_ids.
   Now the adapter builds a channel_keys set that includes the channel ID,
   channel name, and #channel-name form, and checks all three for each gate.

2. Flush pending text-batch tasks before shutdown
   The Discord adapter uses _pending_text_batch_tasks (its own dict) for
   merging rapid successive message chunks. These tasks were NOT added to
   self._background_tasks (the base class list), so the base
   cancel_background_tasks() never awaited them on restart/shutdown.
   This caused a race: in-flight response deliveries were cancelled before
   Discord had a chance to send them, resulting in silent dropped messages
   visible to users as tool-log-only replies with no text body.

   Fix: override cancel_background_tasks() in DiscordAdapter to await all
   pending text-batch tasks (8s deadline) before delegating to the base class.
This commit is contained in:
Cypher 2026-04-12 02:13:00 +02:00 committed by Teknium
parent b03635daea
commit cb9308f0a6
3 changed files with 270 additions and 15 deletions

View file

@ -1104,10 +1104,8 @@ class DiscordAdapter(BasePlatformAdapter):
if hasattr(message.channel, "parent_id") and message.channel.parent_id:
_parent_id = str(message.channel.parent_id)
_free_channels = adapter_self._discord_free_response_channels()
_channel_ids = {_channel_id}
if _parent_id:
_channel_ids.add(_parent_id)
if "*" not in _free_channels and not (_channel_ids & _free_channels):
_channel_keys = adapter_self._discord_channel_keys(message, _parent_id)
if "*" not in _free_channels and not (_channel_keys & _free_channels):
return
await self._handle_message(message, role_authorized=_role_authorized)
@ -1276,6 +1274,67 @@ class DiscordAdapter(BasePlatformAdapter):
pass
self._liveness_task = None
async def cancel_background_tasks(self) -> None:
"""Cancel background tasks, but first flush any pending text-batch sends.
The base-class implementation only cancels tasks in self._background_tasks.
Discord keeps its own _pending_text_batch_tasks dict for the message-merge
logic, and those tasks are NOT in _background_tasks. On shutdown/restart
this caused a race where in-flight response deliveries were cancelled before
Discord had a chance to actually send them, resulting in silent dropped
messages visible to the user as tool-log-only replies with no text.
Fix: await all pending text-batch tasks before delegating to the base
cancel. The flush deadline is clamped below the gateway's per-adapter
disconnect budget (``HERMES_GATEWAY_ADAPTER_DISCONNECT_TIMEOUT``, default
5s) so the gateway's outer ``wait_for`` can't hard-cancel us mid-flush
we cancel our own stragglers cleanly inside the budget instead.
"""
pending = list(self._pending_text_batch_tasks.values())
if pending:
logger.info(
"[%s] Flushing %d pending text-batch task(s) before shutdown",
self.name, len(pending),
)
try:
await asyncio.wait_for(
asyncio.gather(*pending, return_exceptions=True),
timeout=self._text_batch_flush_deadline_seconds(),
)
except asyncio.TimeoutError:
logger.warning(
"[%s] Text-batch flush timed out; cancelling remaining tasks",
self.name,
)
for task in pending:
if not task.done():
task.cancel()
self._pending_text_batch_tasks.clear()
self._pending_text_batches.clear()
await super().cancel_background_tasks()
def _text_batch_flush_deadline_seconds(self) -> float:
"""Deadline for flushing pending text batches during shutdown.
Kept strictly below the gateway's per-adapter disconnect budget so the
gateway's outer ``asyncio.wait_for`` (which wraps this whole method) does
not cancel an in-progress flush before we get a chance to cancel our own
stragglers gracefully. Mirrors the env var the gateway reads in
``GatewayRunner._adapter_disconnect_timeout_secs``.
"""
budget = 5.0 # mirrors gateway _ADAPTER_DISCONNECT_TIMEOUT_SECS_DEFAULT
raw = os.getenv("HERMES_GATEWAY_ADAPTER_DISCONNECT_TIMEOUT", "").strip()
if raw:
try:
parsed = float(raw)
if parsed > 0:
budget = parsed
except ValueError:
pass
# Leave ~20% headroom (min 0.5s) so the outer wait_for can't pre-empt our
# own straggler cancellation, and never go below 1s for the happy path.
return max(1.0, budget - max(0.5, budget * 0.2))
async def disconnect(self) -> None:
"""Disconnect from Discord."""
self._disconnecting = True
@ -2993,6 +3052,16 @@ class DiscordAdapter(BasePlatformAdapter):
if parent_id:
channel_ids.add(str(parent_id))
# Name-form keys (ID + bare name + #name + parent) so allow/ignore
# lists configured by channel name work for slash-command
# interactions too, matching the on_message gates.
channel_keys = self._discord_channel_keys_from_channel(
chan_obj,
self._get_parent_channel_id(chan_obj)
if isinstance(chan_obj, discord.Thread)
else None,
)
allowed_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "")
if allowed_raw:
allowed = {c.strip() for c in allowed_raw.split(",") if c.strip()}
@ -3004,7 +3073,7 @@ class DiscordAdapter(BasePlatformAdapter):
False,
"channel id missing with DISCORD_ALLOWED_CHANNELS configured",
)
if not (channel_ids & allowed):
if not (channel_keys & allowed):
return (False, "channel not in DISCORD_ALLOWED_CHANNELS")
# Ignored beats allowed: even when a thread's parent channel
@ -3013,7 +3082,7 @@ class DiscordAdapter(BasePlatformAdapter):
ignored_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "")
if ignored_raw and channel_ids:
ignored = {c.strip() for c in ignored_raw.split(",") if c.strip()}
if "*" in ignored or (channel_ids & ignored):
if "*" in ignored or (channel_keys & ignored):
return (False, "channel in DISCORD_IGNORED_CHANNELS")
# ── User / role allowlist (mirrors on_message line 681) ──
@ -4336,7 +4405,7 @@ class DiscordAdapter(BasePlatformAdapter):
)
def _discord_free_response_channels(self) -> set:
"""Return Discord channel IDs where no bot mention is required.
"""Return Discord channel IDs/names where no bot mention is required.
A single ``"*"`` entry (either from a list or a comma-separated
string) is preserved in the returned set so callers can short-circuit
@ -4358,6 +4427,50 @@ class DiscordAdapter(BasePlatformAdapter):
return {part.strip() for part in s.split(",") if part.strip()}
return set()
def _discord_channel_keys(self, message: Any, parent_channel_id: Optional[str] = None) -> set[str]:
"""Return channel identifiers accepted by Discord channel config gates.
Users commonly configure channels by Discord snowflake ID, bare name, or
``#name``. Include the current channel and, for threads, the parent
channel so free-response/no-thread/allow/ignore rules work with either
form.
"""
channel = getattr(message, "channel", None)
return self._discord_channel_keys_from_channel(channel, parent_channel_id)
def _discord_channel_keys_from_channel(
self, channel: Any, parent_channel_id: Optional[str] = None
) -> set[str]:
"""Build channel-config gate keys directly from a channel object.
Same key set as :meth:`_discord_channel_keys` (ID, bare name, ``#name``,
and the parent channel for threads) but takes the channel directly so
callers holding an ``interaction.channel`` (slash-command authorization)
get name-form matching too not just the ``on_message`` path.
"""
keys: set[str] = set()
channel_id = getattr(channel, "id", None)
if channel_id is not None:
keys.add(str(channel_id))
channel_name = str(getattr(channel, "name", "")).strip()
if channel_name:
keys.add(channel_name)
keys.add(f"#{channel_name}")
parent_id = parent_channel_id or getattr(channel, "parent_id", None)
if parent_id:
keys.add(str(parent_id))
parent_channel = getattr(channel, "parent", None)
parent_name = str(getattr(parent_channel, "name", "")).strip() if parent_channel else ""
if parent_name:
keys.add(parent_name)
keys.add(f"#{parent_name}")
return keys
def _discord_thread_require_mention(self) -> bool:
"""Return whether thread participation requires @mention to follow up.
@ -5350,25 +5463,24 @@ class DiscordAdapter(BasePlatformAdapter):
channel_ids = {str(message.channel.id)}
if parent_channel_id:
channel_ids.add(parent_channel_id)
channel_keys = self._discord_channel_keys(message, parent_channel_id)
# Check allowed channels - if set, only respond in these channels
allowed_channels_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "")
if allowed_channels_raw:
allowed_channels = {ch.strip() for ch in allowed_channels_raw.split(",") if ch.strip()}
if "*" not in allowed_channels and not (channel_ids & allowed_channels):
logger.debug("[%s] Ignoring message in non-allowed channel: %s", self.name, channel_ids)
if "*" not in allowed_channels and not (channel_keys & allowed_channels):
logger.debug("[%s] Ignoring message in non-allowed channel: %s", self.name, channel_keys)
return
# Check ignored channels - never respond even when mentioned
ignored_channels_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "")
ignored_channels = {ch.strip() for ch in ignored_channels_raw.split(",") if ch.strip()}
if "*" in ignored_channels or (channel_ids & ignored_channels):
logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_ids)
if "*" in ignored_channels or (channel_keys & ignored_channels):
logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_keys)
return
free_channels = self._discord_free_response_channels()
if parent_channel_id:
channel_ids.add(parent_channel_id)
require_mention = self._discord_require_mention()
# Voice-linked text channels act as free-response while voice is active.
@ -5378,7 +5490,7 @@ class DiscordAdapter(BasePlatformAdapter):
is_voice_linked_channel = current_channel_id in voice_linked_ids
is_free_channel = (
"*" in free_channels
or bool(channel_ids & free_channels)
or bool(channel_keys & free_channels)
or is_voice_linked_channel
)
@ -5404,7 +5516,7 @@ class DiscordAdapter(BasePlatformAdapter):
if not is_thread and not isinstance(message.channel, discord.DMChannel):
no_thread_channels_raw = os.getenv("DISCORD_NO_THREAD_CHANNELS", "")
no_thread_channels = {ch.strip() for ch in no_thread_channels_raw.split(",") if ch.strip()}
skip_thread = bool(channel_ids & no_thread_channels) or is_free_channel
skip_thread = bool(channel_keys & no_thread_channels) or is_free_channel
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in {"true", "1", "yes"}
is_reply_message = getattr(message, "type", None) == discord.MessageType.reply
if auto_thread and not skip_thread and not is_voice_linked_channel and not is_reply_message:

View file

@ -410,6 +410,74 @@ async def test_discord_reply_message_skips_auto_thread(adapter, monkeypatch):
assert event.source.chat_type == "group"
@pytest.mark.asyncio
async def test_discord_free_response_matches_channel_name(adapter, monkeypatch):
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "cypher")
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
message = make_message(
channel=FakeTextChannel(channel_id=123, name="cypher"),
content="name-configured channel without mention",
)
await adapter._handle_message(message)
adapter.handle_message.assert_awaited_once()
event = adapter.handle_message.await_args.args[0]
assert event.text == "name-configured channel without mention"
@pytest.mark.asyncio
async def test_discord_free_response_matches_hash_channel_name(adapter, monkeypatch):
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "#cypher")
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
message = make_message(
channel=FakeTextChannel(channel_id=123, name="cypher"),
content="hash-name-configured channel without mention",
)
await adapter._handle_message(message)
adapter.handle_message.assert_awaited_once()
@pytest.mark.asyncio
async def test_discord_parent_channel_name_matches_thread_gates(adapter, monkeypatch):
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "#cypher")
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
parent = FakeTextChannel(channel_id=123, name="cypher")
thread = FakeThread(channel_id=456, name="topic", parent=parent)
message = make_message(channel=thread, content="thread message without mention")
await adapter._handle_message(message)
adapter.handle_message.assert_awaited_once()
event = adapter.handle_message.await_args.args[0]
assert event.source.thread_id == "456"
@pytest.mark.asyncio
async def test_discord_no_thread_matches_channel_name(adapter, monkeypatch):
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "cypher")
adapter._auto_create_thread = AsyncMock()
message = make_message(channel=FakeTextChannel(channel_id=123, name="cypher"), content="hello")
await adapter._handle_message(message)
adapter._auto_create_thread.assert_not_awaited()
adapter.handle_message.assert_awaited_once()
event = adapter.handle_message.await_args.args[0]
assert event.source.chat_type == "group"
@pytest.mark.asyncio
async def test_discord_auto_thread_can_be_disabled(adapter, monkeypatch):
"""Setting auto_thread to false skips thread creation."""

View file

@ -0,0 +1,75 @@
"""Regression guard for Discord text-batch flush during gateway shutdown."""
import asyncio
import sys
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType
from gateway.session import SessionSource
def _ensure_discord_mock():
"""Install a mock discord module when discord.py isn't available."""
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
return
discord_mod = MagicMock()
discord_mod.Intents.default.return_value = MagicMock()
discord_mod.Client = MagicMock
discord_mod.File = MagicMock
discord_mod.DMChannel = type("DMChannel", (), {})
discord_mod.Thread = type("Thread", (), {})
discord_mod.ForumChannel = type("ForumChannel", (), {})
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
discord_mod.Interaction = object
discord_mod.Embed = MagicMock
discord_mod.app_commands = SimpleNamespace(
describe=lambda **kwargs: (lambda fn: fn),
choices=lambda **kwargs: (lambda fn: fn),
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
)
ext_mod = MagicMock()
commands_mod = MagicMock()
commands_mod.Bot = MagicMock
ext_mod.commands = commands_mod
sys.modules.setdefault("discord", discord_mod)
sys.modules.setdefault("discord.ext", ext_mod)
sys.modules.setdefault("discord.ext.commands", commands_mod)
_ensure_discord_mock()
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
@pytest.mark.asyncio
async def test_cancel_background_tasks_awaits_pending_text_batch_before_clearing():
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))
flushed = asyncio.Event()
async def pending_flush():
await asyncio.sleep(0)
flushed.set()
task = asyncio.create_task(pending_flush())
adapter._pending_text_batch_tasks["chat"] = task
adapter._pending_text_batches["chat"] = MessageEvent(
text="pending",
message_type=MessageType.TEXT,
source=SessionSource(platform=Platform.DISCORD, chat_id="chat", chat_type="group"),
)
await adapter.cancel_background_tasks()
assert flushed.is_set()
assert task.done()
assert adapter._pending_text_batch_tasks == {}
assert adapter._pending_text_batches == {}