From 9f22f36625fb030adba767b55bb9f4dc3472f514 Mon Sep 17 00:00:00 2001 From: haileymarshall Date: Sat, 18 Apr 2026 18:13:04 +0100 Subject: [PATCH] fix(mcp-oauth): anchor 401 handler task to prevent GC mid-flight MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `handle_401` spawned a dedup'd recovery coroutine via `asyncio.create_task(_do_handle())` and discarded the returned task reference. Python's event loop only keeps weak references to tasks, so the coroutine could be garbage-collected before it called `pending.set_result(...)`. Every concurrent caller awaiting that future then hangs forever, and the `finally: entry.pending_401.pop(...)` cleanup never runs — so subsequent 401s for the same key latch onto the dead future too. Same pattern the adapter-side fixes address (#11997, #11998, #12000, #12001, #12006). Hold the task in a process-wide set on the manager and discard it via `add_done_callback` once it completes. Regression test covers both the structural invariant (task tracked, then removed on completion) and a concurrent dedup path with a forced `gc.collect()` between the handler's await points. --- tests/tools/test_mcp_oauth_manager.py | 95 +++++++++++++++++++++++++++ tools/mcp_oauth_manager.py | 8 ++- 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/tests/tools/test_mcp_oauth_manager.py b/tests/tools/test_mcp_oauth_manager.py index 5554f245e..2e7d3aa41 100644 --- a/tests/tools/test_mcp_oauth_manager.py +++ b/tests/tools/test_mcp_oauth_manager.py @@ -134,6 +134,101 @@ async def test_disk_watch_invalidates_on_mtime_change(tmp_path, monkeypatch): assert provider._initialized is False +@pytest.mark.asyncio +async def test_handle_401_tracks_inflight_task_to_prevent_gc(tmp_path, monkeypatch): + """The 401 handler task must be strongly referenced by the manager. + + ``asyncio.create_task`` returns a task the event loop only weakly + references. If the manager discards its handle, the background coroutine + can be garbage-collected mid-run and every concurrent waiter stuck on + ``await pending`` hangs forever. See the design note on + ``MCPOAuthManager._inflight_tasks``. + """ + import asyncio + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager, _ProviderEntry + + class _TrackedSet(set): + """set subclass that records every element ever inserted.""" + + def __init__(self): + super().__init__() + self.ever_added: list = [] + + def add(self, item): # noqa: A003 + self.ever_added.append(item) + super().add(item) + + mgr = MCPOAuthManager() + mgr._inflight_tasks = _TrackedSet() + + class _DummyProvider: + context = None # forces the can_refresh=False branch + + mgr._entries["srv"] = _ProviderEntry( + server_url="https://example.com/mcp", + oauth_config=None, + provider=_DummyProvider(), + ) + + result = await mgr.handle_401("srv", failed_access_token="TOK") + + # Exactly one handler task was created and tracked. + assert len(mgr._inflight_tasks.ever_added) == 1 + tracked_task = mgr._inflight_tasks.ever_added[0] + assert isinstance(tracked_task, asyncio.Task) + # done_callback must have removed the finished task from the live set, + # otherwise the set would grow unbounded across repeated 401s. + assert tracked_task not in mgr._inflight_tasks + assert len(mgr._inflight_tasks) == 0 + assert tracked_task.done() + # With provider.context=None, there's nothing to refresh — result False. + assert result is False + + +@pytest.mark.asyncio +async def test_handle_401_dedup_survives_even_if_task_reference_dropped(tmp_path, monkeypatch): + """Concurrent 401s share one handler task and all callers resolve. + + Regression guard: if the manager ever stops holding a strong reference + to the `_do_handle` task, this test can intermittently hang when the + task is GC'd between the ``await`` checkpoints inside ``_do_handle``. + Running it in CI with ``gc.collect()`` mid-flight (below) exercises + that window. + """ + import asyncio + import gc + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager, _ProviderEntry + + mgr = MCPOAuthManager() + + class _DummyProvider: + context = None + + mgr._entries["srv"] = _ProviderEntry( + server_url="https://example.com/mcp", + oauth_config=None, + provider=_DummyProvider(), + ) + + # Fan out N concurrent callers sharing the same failed token so all + # collapse onto a single deduped handler future. + async def _caller(): + return await mgr.handle_401("srv", failed_access_token="TOK") + + tasks = [asyncio.create_task(_caller()) for _ in range(8)] + # Give the event loop one tick to schedule _do_handle, then force GC. + await asyncio.sleep(0) + gc.collect() + + results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=5.0) + assert results == [False] * 8 + assert len(mgr._inflight_tasks) == 0 + + def test_manager_builds_hermes_provider_subclass(tmp_path, monkeypatch): """get_or_build_provider returns HermesMCPOAuthProvider, not plain OAuthClientProvider.""" from tools.mcp_oauth_manager import ( diff --git a/tools/mcp_oauth_manager.py b/tools/mcp_oauth_manager.py index 1011c16bd..8fe1c66f8 100644 --- a/tools/mcp_oauth_manager.py +++ b/tools/mcp_oauth_manager.py @@ -451,6 +451,10 @@ class MCPOAuthManager: def __init__(self) -> None: self._entries: dict[str, _ProviderEntry] = {} self._entries_lock = threading.Lock() + # Holds strong references to in-flight 401 handler tasks so the + # event loop's weak-reference bookkeeping cannot GC them mid-run + # and leave `await pending` waiters hanging forever. + self._inflight_tasks: set[asyncio.Task] = set() # -- Provider construction / caching ------------------------------------- @@ -677,7 +681,9 @@ class MCPOAuthManager: finally: entry.pending_401.pop(key, None) - asyncio.create_task(_do_handle()) + task = asyncio.create_task(_do_handle()) + self._inflight_tasks.add(task) + task.add_done_callback(self._inflight_tasks.discard) try: return await pending