diff --git a/plugins/platforms/discord/adapter.py b/plugins/platforms/discord/adapter.py index 16aa51246..2595fc702 100644 --- a/plugins/platforms/discord/adapter.py +++ b/plugins/platforms/discord/adapter.py @@ -1104,10 +1104,8 @@ class DiscordAdapter(BasePlatformAdapter): if hasattr(message.channel, "parent_id") and message.channel.parent_id: _parent_id = str(message.channel.parent_id) _free_channels = adapter_self._discord_free_response_channels() - _channel_ids = {_channel_id} - if _parent_id: - _channel_ids.add(_parent_id) - if "*" not in _free_channels and not (_channel_ids & _free_channels): + _channel_keys = adapter_self._discord_channel_keys(message, _parent_id) + if "*" not in _free_channels and not (_channel_keys & _free_channels): return await self._handle_message(message, role_authorized=_role_authorized) @@ -1276,6 +1274,67 @@ class DiscordAdapter(BasePlatformAdapter): pass self._liveness_task = None + async def cancel_background_tasks(self) -> None: + """Cancel background tasks, but first flush any pending text-batch sends. + + The base-class implementation only cancels tasks in self._background_tasks. + Discord keeps its own _pending_text_batch_tasks dict for the message-merge + logic, and those tasks are NOT in _background_tasks. On shutdown/restart + this caused a race where in-flight response deliveries were cancelled before + Discord had a chance to actually send them, resulting in silent dropped + messages visible to the user as tool-log-only replies with no text. + + Fix: await all pending text-batch tasks before delegating to the base + cancel. The flush deadline is clamped below the gateway's per-adapter + disconnect budget (``HERMES_GATEWAY_ADAPTER_DISCONNECT_TIMEOUT``, default + 5s) so the gateway's outer ``wait_for`` can't hard-cancel us mid-flush — + we cancel our own stragglers cleanly inside the budget instead. + """ + pending = list(self._pending_text_batch_tasks.values()) + if pending: + logger.info( + "[%s] Flushing %d pending text-batch task(s) before shutdown", + self.name, len(pending), + ) + try: + await asyncio.wait_for( + asyncio.gather(*pending, return_exceptions=True), + timeout=self._text_batch_flush_deadline_seconds(), + ) + except asyncio.TimeoutError: + logger.warning( + "[%s] Text-batch flush timed out; cancelling remaining tasks", + self.name, + ) + for task in pending: + if not task.done(): + task.cancel() + self._pending_text_batch_tasks.clear() + self._pending_text_batches.clear() + await super().cancel_background_tasks() + + def _text_batch_flush_deadline_seconds(self) -> float: + """Deadline for flushing pending text batches during shutdown. + + Kept strictly below the gateway's per-adapter disconnect budget so the + gateway's outer ``asyncio.wait_for`` (which wraps this whole method) does + not cancel an in-progress flush before we get a chance to cancel our own + stragglers gracefully. Mirrors the env var the gateway reads in + ``GatewayRunner._adapter_disconnect_timeout_secs``. + """ + budget = 5.0 # mirrors gateway _ADAPTER_DISCONNECT_TIMEOUT_SECS_DEFAULT + raw = os.getenv("HERMES_GATEWAY_ADAPTER_DISCONNECT_TIMEOUT", "").strip() + if raw: + try: + parsed = float(raw) + if parsed > 0: + budget = parsed + except ValueError: + pass + # Leave ~20% headroom (min 0.5s) so the outer wait_for can't pre-empt our + # own straggler cancellation, and never go below 1s for the happy path. + return max(1.0, budget - max(0.5, budget * 0.2)) + async def disconnect(self) -> None: """Disconnect from Discord.""" self._disconnecting = True @@ -2993,6 +3052,16 @@ class DiscordAdapter(BasePlatformAdapter): if parent_id: channel_ids.add(str(parent_id)) + # Name-form keys (ID + bare name + #name + parent) so allow/ignore + # lists configured by channel name work for slash-command + # interactions too, matching the on_message gates. + channel_keys = self._discord_channel_keys_from_channel( + chan_obj, + self._get_parent_channel_id(chan_obj) + if isinstance(chan_obj, discord.Thread) + else None, + ) + allowed_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "") if allowed_raw: allowed = {c.strip() for c in allowed_raw.split(",") if c.strip()} @@ -3004,7 +3073,7 @@ class DiscordAdapter(BasePlatformAdapter): False, "channel id missing with DISCORD_ALLOWED_CHANNELS configured", ) - if not (channel_ids & allowed): + if not (channel_keys & allowed): return (False, "channel not in DISCORD_ALLOWED_CHANNELS") # Ignored beats allowed: even when a thread's parent channel @@ -3013,7 +3082,7 @@ class DiscordAdapter(BasePlatformAdapter): ignored_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "") if ignored_raw and channel_ids: ignored = {c.strip() for c in ignored_raw.split(",") if c.strip()} - if "*" in ignored or (channel_ids & ignored): + if "*" in ignored or (channel_keys & ignored): return (False, "channel in DISCORD_IGNORED_CHANNELS") # ── User / role allowlist (mirrors on_message line 681) ── @@ -4336,7 +4405,7 @@ class DiscordAdapter(BasePlatformAdapter): ) def _discord_free_response_channels(self) -> set: - """Return Discord channel IDs where no bot mention is required. + """Return Discord channel IDs/names where no bot mention is required. A single ``"*"`` entry (either from a list or a comma-separated string) is preserved in the returned set so callers can short-circuit @@ -4358,6 +4427,50 @@ class DiscordAdapter(BasePlatformAdapter): return {part.strip() for part in s.split(",") if part.strip()} return set() + def _discord_channel_keys(self, message: Any, parent_channel_id: Optional[str] = None) -> set[str]: + """Return channel identifiers accepted by Discord channel config gates. + + Users commonly configure channels by Discord snowflake ID, bare name, or + ``#name``. Include the current channel and, for threads, the parent + channel so free-response/no-thread/allow/ignore rules work with either + form. + """ + channel = getattr(message, "channel", None) + return self._discord_channel_keys_from_channel(channel, parent_channel_id) + + def _discord_channel_keys_from_channel( + self, channel: Any, parent_channel_id: Optional[str] = None + ) -> set[str]: + """Build channel-config gate keys directly from a channel object. + + Same key set as :meth:`_discord_channel_keys` (ID, bare name, ``#name``, + and the parent channel for threads) but takes the channel directly so + callers holding an ``interaction.channel`` (slash-command authorization) + get name-form matching too — not just the ``on_message`` path. + """ + keys: set[str] = set() + + channel_id = getattr(channel, "id", None) + if channel_id is not None: + keys.add(str(channel_id)) + + channel_name = str(getattr(channel, "name", "")).strip() + if channel_name: + keys.add(channel_name) + keys.add(f"#{channel_name}") + + parent_id = parent_channel_id or getattr(channel, "parent_id", None) + if parent_id: + keys.add(str(parent_id)) + + parent_channel = getattr(channel, "parent", None) + parent_name = str(getattr(parent_channel, "name", "")).strip() if parent_channel else "" + if parent_name: + keys.add(parent_name) + keys.add(f"#{parent_name}") + + return keys + def _discord_thread_require_mention(self) -> bool: """Return whether thread participation requires @mention to follow up. @@ -5350,25 +5463,24 @@ class DiscordAdapter(BasePlatformAdapter): channel_ids = {str(message.channel.id)} if parent_channel_id: channel_ids.add(parent_channel_id) + channel_keys = self._discord_channel_keys(message, parent_channel_id) # Check allowed channels - if set, only respond in these channels allowed_channels_raw = os.getenv("DISCORD_ALLOWED_CHANNELS", "") if allowed_channels_raw: allowed_channels = {ch.strip() for ch in allowed_channels_raw.split(",") if ch.strip()} - if "*" not in allowed_channels and not (channel_ids & allowed_channels): - logger.debug("[%s] Ignoring message in non-allowed channel: %s", self.name, channel_ids) + if "*" not in allowed_channels and not (channel_keys & allowed_channels): + logger.debug("[%s] Ignoring message in non-allowed channel: %s", self.name, channel_keys) return # Check ignored channels - never respond even when mentioned ignored_channels_raw = os.getenv("DISCORD_IGNORED_CHANNELS", "") ignored_channels = {ch.strip() for ch in ignored_channels_raw.split(",") if ch.strip()} - if "*" in ignored_channels or (channel_ids & ignored_channels): - logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_ids) + if "*" in ignored_channels or (channel_keys & ignored_channels): + logger.debug("[%s] Ignoring message in ignored channel: %s", self.name, channel_keys) return free_channels = self._discord_free_response_channels() - if parent_channel_id: - channel_ids.add(parent_channel_id) require_mention = self._discord_require_mention() # Voice-linked text channels act as free-response while voice is active. @@ -5378,7 +5490,7 @@ class DiscordAdapter(BasePlatformAdapter): is_voice_linked_channel = current_channel_id in voice_linked_ids is_free_channel = ( "*" in free_channels - or bool(channel_ids & free_channels) + or bool(channel_keys & free_channels) or is_voice_linked_channel ) @@ -5404,7 +5516,7 @@ class DiscordAdapter(BasePlatformAdapter): if not is_thread and not isinstance(message.channel, discord.DMChannel): no_thread_channels_raw = os.getenv("DISCORD_NO_THREAD_CHANNELS", "") no_thread_channels = {ch.strip() for ch in no_thread_channels_raw.split(",") if ch.strip()} - skip_thread = bool(channel_ids & no_thread_channels) or is_free_channel + skip_thread = bool(channel_keys & no_thread_channels) or is_free_channel auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in {"true", "1", "yes"} is_reply_message = getattr(message, "type", None) == discord.MessageType.reply if auto_thread and not skip_thread and not is_voice_linked_channel and not is_reply_message: diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py index fbf7fc56a..1c71ac641 100644 --- a/tests/gateway/test_discord_free_response.py +++ b/tests/gateway/test_discord_free_response.py @@ -410,6 +410,74 @@ async def test_discord_reply_message_skips_auto_thread(adapter, monkeypatch): assert event.source.chat_type == "group" +@pytest.mark.asyncio +async def test_discord_free_response_matches_channel_name(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "cypher") + monkeypatch.setenv("DISCORD_AUTO_THREAD", "false") + + message = make_message( + channel=FakeTextChannel(channel_id=123, name="cypher"), + content="name-configured channel without mention", + ) + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "name-configured channel without mention" + + +@pytest.mark.asyncio +async def test_discord_free_response_matches_hash_channel_name(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "#cypher") + monkeypatch.setenv("DISCORD_AUTO_THREAD", "false") + + message = make_message( + channel=FakeTextChannel(channel_id=123, name="cypher"), + content="hash-name-configured channel without mention", + ) + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_discord_parent_channel_name_matches_thread_gates(adapter, monkeypatch): + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.setenv("DISCORD_FREE_RESPONSE_CHANNELS", "#cypher") + monkeypatch.setenv("DISCORD_AUTO_THREAD", "false") + + parent = FakeTextChannel(channel_id=123, name="cypher") + thread = FakeThread(channel_id=456, name="topic", parent=parent) + message = make_message(channel=thread, content="thread message without mention") + + await adapter._handle_message(message) + + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.source.thread_id == "456" + + +@pytest.mark.asyncio +async def test_discord_no_thread_matches_channel_name(adapter, monkeypatch): + monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False) + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false") + monkeypatch.setenv("DISCORD_NO_THREAD_CHANNELS", "cypher") + + adapter._auto_create_thread = AsyncMock() + message = make_message(channel=FakeTextChannel(channel_id=123, name="cypher"), content="hello") + + await adapter._handle_message(message) + + adapter._auto_create_thread.assert_not_awaited() + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.source.chat_type == "group" + + @pytest.mark.asyncio async def test_discord_auto_thread_can_be_disabled(adapter, monkeypatch): """Setting auto_thread to false skips thread creation.""" diff --git a/tests/gateway/test_discord_pending_text_batch_shutdown.py b/tests/gateway/test_discord_pending_text_batch_shutdown.py new file mode 100644 index 000000000..d89b65f58 --- /dev/null +++ b/tests/gateway/test_discord_pending_text_batch_shutdown.py @@ -0,0 +1,75 @@ +"""Regression guard for Discord text-batch flush during gateway shutdown.""" + +import asyncio +import sys +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import MessageEvent, MessageType +from gateway.session import SessionSource + + +def _ensure_discord_mock(): + """Install a mock discord module when discord.py isn't available.""" + if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"): + return + + discord_mod = MagicMock() + discord_mod.Intents.default.return_value = MagicMock() + discord_mod.Client = MagicMock + discord_mod.File = MagicMock + discord_mod.DMChannel = type("DMChannel", (), {}) + discord_mod.Thread = type("Thread", (), {}) + discord_mod.ForumChannel = type("ForumChannel", (), {}) + discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object) + discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3) + discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5) + discord_mod.Interaction = object + discord_mod.Embed = MagicMock + discord_mod.app_commands = SimpleNamespace( + describe=lambda **kwargs: (lambda fn: fn), + choices=lambda **kwargs: (lambda fn: fn), + Choice=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + ext_mod = MagicMock() + commands_mod = MagicMock() + commands_mod.Bot = MagicMock + ext_mod.commands = commands_mod + + sys.modules.setdefault("discord", discord_mod) + sys.modules.setdefault("discord.ext", ext_mod) + sys.modules.setdefault("discord.ext.commands", commands_mod) + + +_ensure_discord_mock() + +from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402 + + +@pytest.mark.asyncio +async def test_cancel_background_tasks_awaits_pending_text_batch_before_clearing(): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token")) + flushed = asyncio.Event() + + async def pending_flush(): + await asyncio.sleep(0) + flushed.set() + + task = asyncio.create_task(pending_flush()) + adapter._pending_text_batch_tasks["chat"] = task + adapter._pending_text_batches["chat"] = MessageEvent( + text="pending", + message_type=MessageType.TEXT, + source=SessionSource(platform=Platform.DISCORD, chat_id="chat", chat_type="group"), + ) + + await adapter.cancel_background_tasks() + + assert flushed.is_set() + assert task.done() + assert adapter._pending_text_batch_tasks == {} + assert adapter._pending_text_batches == {}