fix(discord): channel name matching and flush pending sends on shutdown
Two related fixes to the Discord gateway adapter: 1. Channel name matching (free-response, allowed, ignored, no-thread channels) Previously these config values only matched against numeric channel IDs. If a user configured free_response_channels: cypher (by name), the adapter would silently ignore it because it only intersected against channel_ids. Now the adapter builds a channel_keys set that includes the channel ID, channel name, and #channel-name form, and checks all three for each gate. 2. Flush pending text-batch tasks before shutdown The Discord adapter uses _pending_text_batch_tasks (its own dict) for merging rapid successive message chunks. These tasks were NOT added to self._background_tasks (the base class list), so the base cancel_background_tasks() never awaited them on restart/shutdown. This caused a race: in-flight response deliveries were cancelled before Discord had a chance to send them, resulting in silent dropped messages visible to users as tool-log-only replies with no text body. Fix: override cancel_background_tasks() in DiscordAdapter to await all pending text-batch tasks (8s deadline) before delegating to the base class.
This commit is contained in:
parent
b03635daea
commit
cb9308f0a6
3 changed files with 270 additions and 15 deletions
|
|
@ -1104,10 +1104,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if hasattr(message.channel, "parent_id") and message.channel.parent_id:
|
||||
_parent_id = str(message.channel.parent_id)
|
||||
_free_channels = adapter_self._discord_free_response_channels()
|
||||
_channel_ids = {_channel_id}
|
||||
if _parent_id:
|
||||
_channel_ids.add(_parent_id)
|
||||
if "*" not in _free_channels and not (_channel_ids & _free_channels):
|
||||
_channel_keys = adapter_self._discord_channel_keys(message, _parent_id)
|
||||
if "*" not in _free_channels and not (_channel_keys & _free_channels):
|
||||
return
|
||||
|
||||
await self._handle_message(message, role_authorized=_role_authorized)
|
||||
|
|
@ -1276,6 +1274,67 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
pass
|
||||
self._liveness_task = None
|
||||
|
||||
async def cancel_background_tasks(self) -> None:
|
||||
"""Cancel background tasks, but first flush any pending text-batch sends.
|
||||
|
||||
The base-class implementation only cancels tasks in self._background_tasks.
|
||||
Discord keeps its own _pending_text_batch_tasks dict for the message-merge
|
||||
logic, and those tasks are NOT in _background_tasks. On shutdown/restart
|
||||
this caused a race where in-flight response deliveries were cancelled before
|
||||
Discord had a chance to actually send them, resulting in silent dropped
|
||||
messages visible to the user as tool-log-only replies with no text.
|
||||
|
||||
Fix: await all pending text-batch tasks before delegating to the base
|
||||
cancel. The flush deadline is clamped below the gateway's per-adapter
|
||||
disconnect budget (``HERMES_GATEWAY_ADAPTER_DISCONNECT_TIMEOUT``, default
|
||||
5s) so the gateway's outer ``wait_for`` can't hard-cancel us mid-flush —
|
||||
we cancel our own stragglers cleanly inside the budget instead.
|
||||
"""
|
||||
pending = list(self._pending_text_batch_tasks.values())
|
||||
if pending:
|
||||
logger.info(
|
||||
"[%s] Flushing %d pending text-batch task(s) before shutdown",
|
||||
self.name, len(pending),
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*pending, return_exceptions=True),
|
||||
timeout=self._text_batch_flush_deadline_seconds(),
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"[%s] Text-batch flush timed out; cancelling remaining tasks",
|
||||
self.name,
|
||||
)
|
||||
for task in pending:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
self._pending_text_batch_tasks.clear()
|
||||
self._pending_text_batches.clear()
|
||||
await super().cancel_background_tasks()
|
||||
|
||||
def _text_batch_flush_deadline_seconds(self) -> float:
|
||||
"""Deadline for flushing pending text batches during shutdown.
|
||||
|
||||
Kept strictly below the gateway's per-adapter disconnect budget so the
|
||||
gateway's outer ``asyncio.wait_for`` (which wraps this whole method) does
|
||||
not cancel an in-progress flush before we get a chance to cancel our own
|
||||
stragglers gracefully. Mirrors the env var the gateway reads in
|
||||
``GatewayRunner._adapter_disconnect_timeout_secs``.
|
||||
"""
|
||||
budget = 5.0 # mirrors gateway _ADAPTER_DISCONNECT_TIMEOUT_SECS_DEFAULT
|
||||
raw = os.getenv("HERMES_GATEWAY_ADAPTER_DISCONNECT_TIMEOUT", "").strip()
|
||||
if raw:
|
||||
try:
|
||||
parsed = float(raw)
|
||||
if parsed > 0:
|
||||
budget = parsed
|
||||
except ValueError:
|
||||
pass
|
||||
# Leave ~20% headroom (min 0.5s) so the outer wait_for can't pre-empt our
|
||||
# own straggler cancellation, and never go below 1s for the happy path.
|
||||
return max(1.0, budget - max(0.5, budget * 0.2))
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Discord."""
|
||||
self._disconnecting = True
|
||||
|
|
@ -2993,6 +3052,16 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if parent_id:
|
||||
channel_ids.add(str(parent_id))
|
||||
|
||||
# Name-form keys (ID + bare name + #name + parent) so allow/ignore
|
||||
# lists configured by channel name work for slash-command
|
||||
# interactions too, matching the on_message gates.
|
||||
channel_keys = self._discord_channel_keys_from_channel(
|
||||
chan_obj,
|
||||
self._get_parent_channel_id(chan_obj)
|
||||
if isinstance(chan_obj, discord.Thread)
|
||||
else None,
|
||||
)
|
||||
|
||||
allowed_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "")
|
||||
if allowed_raw:
|
||||
allowed = {c.strip() for c in allowed_raw.split(",") if c.strip()}
|
||||
|
|
@ -3004,7 +3073,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
False,
|
||||
"channel id missing with DISCORD_ALLOWED_CHANNELS configured",
|
||||
)
|
||||
if not (channel_ids & allowed):
|
||||
if not (channel_keys & allowed):
|
||||
return (False, "channel not in DISCORD_ALLOWED_CHANNELS")
|
||||
|
||||
# Ignored beats allowed: even when a thread's parent channel
|
||||
|
|
@ -3013,7 +3082,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
ignored_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "")
|
||||
if ignored_raw and channel_ids:
|
||||
ignored = {c.strip() for c in ignored_raw.split(",") if c.strip()}
|
||||
if "*" in ignored or (channel_ids & ignored):
|
||||
if "*" in ignored or (channel_keys & ignored):
|
||||
return (False, "channel in DISCORD_IGNORED_CHANNELS")
|
||||
|
||||
# ── User / role allowlist (mirrors on_message line 681) ──
|
||||
|
|
@ -4336,7 +4405,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
)
|
||||
|
||||
def _discord_free_response_channels(self) -> set:
|
||||
"""Return Discord channel IDs where no bot mention is required.
|
||||
"""Return Discord channel IDs/names where no bot mention is required.
|
||||
|
||||
A single ``"*"`` entry (either from a list or a comma-separated
|
||||
string) is preserved in the returned set so callers can short-circuit
|
||||
|
|
@ -4358,6 +4427,50 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
return {part.strip() for part in s.split(",") if part.strip()}
|
||||
return set()
|
||||
|
||||
def _discord_channel_keys(self, message: Any, parent_channel_id: Optional[str] = None) -> set[str]:
|
||||
"""Return channel identifiers accepted by Discord channel config gates.
|
||||
|
||||
Users commonly configure channels by Discord snowflake ID, bare name, or
|
||||
``#name``. Include the current channel and, for threads, the parent
|
||||
channel so free-response/no-thread/allow/ignore rules work with either
|
||||
form.
|
||||
"""
|
||||
channel = getattr(message, "channel", None)
|
||||
return self._discord_channel_keys_from_channel(channel, parent_channel_id)
|
||||
|
||||
def _discord_channel_keys_from_channel(
|
||||
self, channel: Any, parent_channel_id: Optional[str] = None
|
||||
) -> set[str]:
|
||||
"""Build channel-config gate keys directly from a channel object.
|
||||
|
||||
Same key set as :meth:`_discord_channel_keys` (ID, bare name, ``#name``,
|
||||
and the parent channel for threads) but takes the channel directly so
|
||||
callers holding an ``interaction.channel`` (slash-command authorization)
|
||||
get name-form matching too — not just the ``on_message`` path.
|
||||
"""
|
||||
keys: set[str] = set()
|
||||
|
||||
channel_id = getattr(channel, "id", None)
|
||||
if channel_id is not None:
|
||||
keys.add(str(channel_id))
|
||||
|
||||
channel_name = str(getattr(channel, "name", "")).strip()
|
||||
if channel_name:
|
||||
keys.add(channel_name)
|
||||
keys.add(f"#{channel_name}")
|
||||
|
||||
parent_id = parent_channel_id or getattr(channel, "parent_id", None)
|
||||
if parent_id:
|
||||
keys.add(str(parent_id))
|
||||
|
||||
parent_channel = getattr(channel, "parent", None)
|
||||
parent_name = str(getattr(parent_channel, "name", "")).strip() if parent_channel else ""
|
||||
if parent_name:
|
||||
keys.add(parent_name)
|
||||
keys.add(f"#{parent_name}")
|
||||
|
||||
return keys
|
||||
|
||||
def _discord_thread_require_mention(self) -> bool:
|
||||
"""Return whether thread participation requires @mention to follow up.
|
||||
|
||||
|
|
@ -5350,25 +5463,24 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
channel_ids = {str(message.channel.id)}
|
||||
if parent_channel_id:
|
||||
channel_ids.add(parent_channel_id)
|
||||
channel_keys = self._discord_channel_keys(message, parent_channel_id)
|
||||
|
||||
# Check allowed channels - if set, only respond in these channels
|
||||
allowed_channels_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "")
|
||||
if allowed_channels_raw:
|
||||
allowed_channels = {ch.strip() for ch in allowed_channels_raw.split(",") if ch.strip()}
|
||||
if "*" not in allowed_channels and not (channel_ids & allowed_channels):
|
||||
logger.debug("[%s] Ignoring message in non-allowed channel: %s", self.name, channel_ids)
|
||||
if "*" not in allowed_channels and not (channel_keys & allowed_channels):
|
||||
logger.debug("[%s] Ignoring message in non-allowed channel: %s", self.name, channel_keys)
|
||||
return
|
||||
|
||||
# Check ignored channels - never respond even when mentioned
|
||||
ignored_channels_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "")
|
||||
ignored_channels = {ch.strip() for ch in ignored_channels_raw.split(",") if ch.strip()}
|
||||
if "*" in ignored_channels or (channel_ids & ignored_channels):
|
||||
logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_ids)
|
||||
if "*" in ignored_channels or (channel_keys & ignored_channels):
|
||||
logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_keys)
|
||||
return
|
||||
|
||||
free_channels = self._discord_free_response_channels()
|
||||
if parent_channel_id:
|
||||
channel_ids.add(parent_channel_id)
|
||||
|
||||
require_mention = self._discord_require_mention()
|
||||
# Voice-linked text channels act as free-response while voice is active.
|
||||
|
|
@ -5378,7 +5490,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
is_voice_linked_channel = current_channel_id in voice_linked_ids
|
||||
is_free_channel = (
|
||||
"*" in free_channels
|
||||
or bool(channel_ids & free_channels)
|
||||
or bool(channel_keys & free_channels)
|
||||
or is_voice_linked_channel
|
||||
)
|
||||
|
||||
|
|
@ -5404,7 +5516,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if not is_thread and not isinstance(message.channel, discord.DMChannel):
|
||||
no_thread_channels_raw = os.getenv("DISCORD_NO_THREAD_CHANNELS", "")
|
||||
no_thread_channels = {ch.strip() for ch in no_thread_channels_raw.split(",") if ch.strip()}
|
||||
skip_thread = bool(channel_ids & no_thread_channels) or is_free_channel
|
||||
skip_thread = bool(channel_keys & no_thread_channels) or is_free_channel
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in {"true", "1", "yes"}
|
||||
is_reply_message = getattr(message, "type", None) == discord.MessageType.reply
|
||||
if auto_thread and not skip_thread and not is_voice_linked_channel and not is_reply_message:
|
||||
|
|
|
|||
|
|
@ -410,6 +410,74 @@ async def test_discord_reply_message_skips_auto_thread(adapter, monkeypatch):
|
|||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_free_response_matches_channel_name(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "cypher")
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
message = make_message(
|
||||
channel=FakeTextChannel(channel_id=123, name="cypher"),
|
||||
content="name-configured channel without mention",
|
||||
)
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "name-configured channel without mention"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_free_response_matches_hash_channel_name(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "#cypher")
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
message = make_message(
|
||||
channel=FakeTextChannel(channel_id=123, name="cypher"),
|
||||
content="hash-name-configured channel without mention",
|
||||
)
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_parent_channel_name_matches_thread_gates(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "#cypher")
|
||||
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
|
||||
|
||||
parent = FakeTextChannel(channel_id=123, name="cypher")
|
||||
thread = FakeThread(channel_id=456, name="topic", parent=parent)
|
||||
message = make_message(channel=thread, content="thread message without mention")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.thread_id == "456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_no_thread_matches_channel_name(adapter, monkeypatch):
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
|
||||
monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "cypher")
|
||||
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
message = make_message(channel=FakeTextChannel(channel_id=123, name="cypher"), content="hello")
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_auto_thread_can_be_disabled(adapter, monkeypatch):
|
||||
"""Setting auto_thread to false skips thread creation."""
|
||||
|
|
|
|||
75
tests/gateway/test_discord_pending_text_batch_shutdown.py
Normal file
75
tests/gateway/test_discord_pending_text_batch_shutdown.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""Regression guard for Discord text-batch flush during gateway shutdown."""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_background_tasks_awaits_pending_text_batch_before_clearing():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))
|
||||
flushed = asyncio.Event()
|
||||
|
||||
async def pending_flush():
|
||||
await asyncio.sleep(0)
|
||||
flushed.set()
|
||||
|
||||
task = asyncio.create_task(pending_flush())
|
||||
adapter._pending_text_batch_tasks["chat"] = task
|
||||
adapter._pending_text_batches["chat"] = MessageEvent(
|
||||
text="pending",
|
||||
message_type=MessageType.TEXT,
|
||||
source=SessionSource(platform=Platform.DISCORD, chat_id="chat", chat_type="group"),
|
||||
)
|
||||
|
||||
await adapter.cancel_background_tasks()
|
||||
|
||||
assert flushed.is_set()
|
||||
assert task.done()
|
||||
assert adapter._pending_text_batch_tasks == {}
|
||||
assert adapter._pending_text_batches == {}
|
||||
Loading…
Add table
Add a link
Reference in a new issue