fix(gateway): await async post-delivery callbacks in chained wrapper
When two features register a post-delivery callback for the same session (e.g. background-review release + /goal continuation), the second registration is composed with the first via a `_chained` wrapper. That wrapper was `def _chained()` — a sync function calling each callback via `_prev()` / `_new()` and discarding the return value. For sync callbacks that's fine. For async callbacks (such as the `_deliver()` coroutine the /goal feature registers to inject the continuation prompt) the returned coroutine was silently dropped: RuntimeWarning: coroutine '_deliver' was never awaited. Outer invoker in `_handle_message` already checks `inspect.isawaitable(_post_result)` and awaits — but only sees the wrapper's return value, which was `None`. Fix: make `_chained` async, iterate over chained callbacks, await any that return an awaitable. Outer invoker already handles awaitable wrappers, so no other change is needed. Tested: * Added two regression tests in test_post_delivery_callback_chaining.py covering an async callback chained behind sync (and vice versa). * Updated existing chaining tests + test_run_cleanup_progress.py to await the popped callback when it's awaitable. * 62 tests pass across the touched suites. Live-validated on Discord: /goal continuations now arrive after the first turn's response is delivered (previously silent). Refs: NousResearch/hermes-agent#31922
This commit is contained in:
parent
8b14080e30
commit
74d2660aeb
3 changed files with 98 additions and 19 deletions
|
|
@ -3901,15 +3901,22 @@ class BasePlatformAdapter(ABC):
|
|||
_prev = existing_cb
|
||||
_new = callback
|
||||
|
||||
def _chained() -> None:
|
||||
try:
|
||||
_prev()
|
||||
except Exception:
|
||||
logger.debug("Post-delivery callback failed", exc_info=True)
|
||||
try:
|
||||
_new()
|
||||
except Exception:
|
||||
logger.debug("Post-delivery callback failed", exc_info=True)
|
||||
async def _chained() -> None:
|
||||
# Both _prev and _new may be sync or async. The chained
|
||||
# wrapper itself must be async because the outer invoker
|
||||
# (``_handle_message`` etc.) awaits awaitable callbacks; a
|
||||
# sync wrapper here would call ``_prev()`` / ``_new()`` and
|
||||
# silently drop any returned coroutine, breaking chained
|
||||
# async post-delivery hooks (e.g. ``/goal`` continuations).
|
||||
for _cb in (_prev, _new):
|
||||
try:
|
||||
_result = _cb()
|
||||
if inspect.isawaitable(_result):
|
||||
await _result
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Post-delivery callback failed", exc_info=True
|
||||
)
|
||||
|
||||
callback = _chained
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,15 @@ session (e.g. background-review release + temporary-progress cleanup), the
|
|||
registration API chains them rather than clobbering. Per-callback
|
||||
exceptions are swallowed so one bad callback can't sabotage the others.
|
||||
Stale-generation registrations are rejected.
|
||||
|
||||
The chained wrapper is ``async`` so it transparently supports sync or async
|
||||
callbacks — the outer invoker in ``_handle_message`` awaits awaitable
|
||||
callbacks, and a sync wrapper would silently drop coroutine results from
|
||||
async callbacks chained behind it.
|
||||
"""
|
||||
import asyncio
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
|
@ -31,12 +39,25 @@ def adapter():
|
|||
return _MinAdapter(PlatformConfig(enabled=True), Platform.TELEGRAM)
|
||||
|
||||
|
||||
def _invoke(cb):
|
||||
"""Invoke a popped callback, awaiting if it returns a coroutine.
|
||||
|
||||
Single-registration callbacks are returned as the raw user callable
|
||||
(sync). Chained callbacks (two or more registrations on the same
|
||||
session) are wrapped in an async helper. Tests use this helper so
|
||||
they don't have to care which case they're exercising.
|
||||
"""
|
||||
result = cb()
|
||||
if inspect.isawaitable(result):
|
||||
asyncio.run(result)
|
||||
|
||||
|
||||
class TestPostDeliveryCallbackChaining:
|
||||
def test_single_callback_fires(self, adapter):
|
||||
fired = []
|
||||
adapter.register_post_delivery_callback("s", lambda: fired.append("A"))
|
||||
cb = adapter.pop_post_delivery_callback("s")
|
||||
cb()
|
||||
_invoke(cb)
|
||||
assert fired == ["A"]
|
||||
|
||||
def test_two_callbacks_chain_in_order(self, adapter):
|
||||
|
|
@ -44,7 +65,7 @@ class TestPostDeliveryCallbackChaining:
|
|||
adapter.register_post_delivery_callback("s", lambda: fired.append("A"))
|
||||
adapter.register_post_delivery_callback("s", lambda: fired.append("B"))
|
||||
cb = adapter.pop_post_delivery_callback("s")
|
||||
cb()
|
||||
_invoke(cb)
|
||||
assert fired == ["A", "B"]
|
||||
|
||||
def test_three_callbacks_chain_in_order(self, adapter):
|
||||
|
|
@ -55,7 +76,7 @@ class TestPostDeliveryCallbackChaining:
|
|||
"s", lambda x=label: fired.append(x)
|
||||
)
|
||||
cb = adapter.pop_post_delivery_callback("s")
|
||||
cb()
|
||||
_invoke(cb)
|
||||
assert fired == ["A", "B", "C"]
|
||||
|
||||
def test_exception_in_one_callback_does_not_block_next(self, adapter):
|
||||
|
|
@ -67,7 +88,7 @@ class TestPostDeliveryCallbackChaining:
|
|||
adapter.register_post_delivery_callback("s", boom)
|
||||
adapter.register_post_delivery_callback("s", lambda: fired.append("survived"))
|
||||
cb = adapter.pop_post_delivery_callback("s")
|
||||
cb()
|
||||
_invoke(cb)
|
||||
assert fired == ["survived"]
|
||||
|
||||
def test_same_generation_chains(self, adapter):
|
||||
|
|
@ -79,7 +100,7 @@ class TestPostDeliveryCallbackChaining:
|
|||
"s", lambda: fired.append("B"), generation=5
|
||||
)
|
||||
cb = adapter.pop_post_delivery_callback("s", generation=5)
|
||||
cb()
|
||||
_invoke(cb)
|
||||
assert fired == ["A", "B"]
|
||||
|
||||
def test_stale_generation_registration_rejected(self, adapter):
|
||||
|
|
@ -93,7 +114,7 @@ class TestPostDeliveryCallbackChaining:
|
|||
"s", lambda: fired.append("stale_gen3"), generation=3
|
||||
)
|
||||
cb = adapter.pop_post_delivery_callback("s", generation=7)
|
||||
cb()
|
||||
_invoke(cb)
|
||||
assert fired == ["gen7"]
|
||||
|
||||
def test_pop_at_wrong_generation_returns_none(self, adapter):
|
||||
|
|
@ -111,3 +132,42 @@ class TestPostDeliveryCallbackChaining:
|
|||
def test_non_callable_is_noop(self, adapter):
|
||||
adapter.register_post_delivery_callback("s", "not-callable") # type: ignore[arg-type]
|
||||
assert adapter._post_delivery_callbacks == {}
|
||||
|
||||
|
||||
class TestPostDeliveryCallbackAsyncChaining:
|
||||
"""When an async callback is chained, the wrapper must await it.
|
||||
|
||||
Regression test for a bug where the sync ``_chained`` wrapper called
|
||||
async callbacks without awaiting, silently dropping the returned
|
||||
coroutine. This broke ``/goal`` continuations (Discord etc.) where
|
||||
the continuation injection is an async ``_deliver()`` coroutine.
|
||||
"""
|
||||
|
||||
def test_async_callback_in_chain_is_awaited(self, adapter):
|
||||
fired = []
|
||||
|
||||
async def async_cb():
|
||||
await asyncio.sleep(0)
|
||||
fired.append("async")
|
||||
|
||||
adapter.register_post_delivery_callback("s", lambda: fired.append("sync"))
|
||||
adapter.register_post_delivery_callback("s", async_cb)
|
||||
cb = adapter.pop_post_delivery_callback("s")
|
||||
_invoke(cb)
|
||||
assert fired == ["sync", "async"]
|
||||
|
||||
def test_two_async_callbacks_both_awaited(self, adapter):
|
||||
fired = []
|
||||
|
||||
def make(label):
|
||||
async def _cb():
|
||||
await asyncio.sleep(0)
|
||||
fired.append(label)
|
||||
|
||||
return _cb
|
||||
|
||||
adapter.register_post_delivery_callback("s", make("A"))
|
||||
adapter.register_post_delivery_callback("s", make("B"))
|
||||
cb = adapter.pop_post_delivery_callback("s")
|
||||
_invoke(cb)
|
||||
assert fired == ["A", "B"]
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ Adapters without ``delete_message`` silently no-op.
|
|||
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect as _inspect
|
||||
import sys
|
||||
import time
|
||||
import types
|
||||
|
|
@ -20,6 +21,17 @@ from types import SimpleNamespace
|
|||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
async def _fire_post_delivery_cb(cb):
|
||||
"""Invoke a popped post-delivery callback, awaiting if it's async.
|
||||
|
||||
Chained registrations return an async wrapper; single registrations
|
||||
return the raw sync callable. Either way, await any awaitable result.
|
||||
"""
|
||||
result = cb()
|
||||
if _inspect.isawaitable(result):
|
||||
await result
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
|
@ -215,7 +227,7 @@ async def test_cleanup_off_by_default_leaves_bubbles(monkeypatch, tmp_path):
|
|||
# delete_message calls when cleanup is off.
|
||||
cb = adapter.pop_post_delivery_callback(session_key)
|
||||
if cb is not None:
|
||||
cb()
|
||||
await _fire_post_delivery_cb(cb)
|
||||
for _ in range(10):
|
||||
await asyncio.sleep(0.01)
|
||||
assert adapter.deleted == []
|
||||
|
|
@ -248,7 +260,7 @@ async def test_cleanup_registers_callback_and_deletes_on_success(monkeypatch, tm
|
|||
|
||||
# Fire it (base.py does this in _process_message_background's finally)
|
||||
# and let the scheduled coroutine run to completion.
|
||||
cb()
|
||||
await _fire_post_delivery_cb(cb)
|
||||
# delete_message is scheduled via run_coroutine_threadsafe → give the
|
||||
# loop a couple of ticks to drain.
|
||||
for _ in range(20):
|
||||
|
|
@ -287,7 +299,7 @@ async def test_cleanup_skipped_on_failed_run(monkeypatch, tmp_path):
|
|||
# the cleanup callback is skipped on failed runs.
|
||||
cb = adapter.pop_post_delivery_callback(session_key)
|
||||
if cb is not None:
|
||||
cb()
|
||||
await _fire_post_delivery_cb(cb)
|
||||
for _ in range(10):
|
||||
await asyncio.sleep(0.01)
|
||||
assert adapter.deleted == []
|
||||
|
|
@ -355,7 +367,7 @@ async def test_cleanup_chains_with_existing_callback(monkeypatch, tmp_path):
|
|||
assert result["final_response"] == "done"
|
||||
cb = adapter.pop_post_delivery_callback(session_key)
|
||||
assert callable(cb)
|
||||
cb()
|
||||
await _fire_post_delivery_cb(cb)
|
||||
for _ in range(20):
|
||||
await asyncio.sleep(0.01)
|
||||
if adapter.deleted:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue