diff --git a/tests/tools/test_mcp_oauth_manager.py b/tests/tools/test_mcp_oauth_manager.py index 5554f245e..2e7d3aa41 100644 --- a/tests/tools/test_mcp_oauth_manager.py +++ b/tests/tools/test_mcp_oauth_manager.py @@ -134,6 +134,101 @@ async def test_disk_watch_invalidates_on_mtime_change(tmp_path, monkeypatch): assert provider._initialized is False +@pytest.mark.asyncio +async def test_handle_401_tracks_inflight_task_to_prevent_gc(tmp_path, monkeypatch): + """The 401 handler task must be strongly referenced by the manager. + + ``asyncio.create_task`` returns a task the event loop only weakly + references. If the manager discards its handle, the background coroutine + can be garbage-collected mid-run and every concurrent waiter stuck on + ``await pending`` hangs forever. See the design note on + ``MCPOAuthManager._inflight_tasks``. + """ + import asyncio + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager, _ProviderEntry + + class _TrackedSet(set): + """set subclass that records every element ever inserted.""" + + def __init__(self): + super().__init__() + self.ever_added: list = [] + + def add(self, item): # noqa: A003 + self.ever_added.append(item) + super().add(item) + + mgr = MCPOAuthManager() + mgr._inflight_tasks = _TrackedSet() + + class _DummyProvider: + context = None # forces the can_refresh=False branch + + mgr._entries["srv"] = _ProviderEntry( + server_url="https://example.com/mcp", + oauth_config=None, + provider=_DummyProvider(), + ) + + result = await mgr.handle_401("srv", failed_access_token="TOK") + + # Exactly one handler task was created and tracked. + assert len(mgr._inflight_tasks.ever_added) == 1 + tracked_task = mgr._inflight_tasks.ever_added[0] + assert isinstance(tracked_task, asyncio.Task) + # done_callback must have removed the finished task from the live set, + # otherwise the set would grow unbounded across repeated 401s. + assert tracked_task not in mgr._inflight_tasks + assert len(mgr._inflight_tasks) == 0 + assert tracked_task.done() + # With provider.context=None, there's nothing to refresh — result False. + assert result is False + + +@pytest.mark.asyncio +async def test_handle_401_dedup_survives_even_if_task_reference_dropped(tmp_path, monkeypatch): + """Concurrent 401s share one handler task and all callers resolve. + + Regression guard: if the manager ever stops holding a strong reference + to the `_do_handle` task, this test can intermittently hang when the + task is GC'd between the ``await`` checkpoints inside ``_do_handle``. + Running it in CI with ``gc.collect()`` mid-flight (below) exercises + that window. + """ + import asyncio + import gc + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager, _ProviderEntry + + mgr = MCPOAuthManager() + + class _DummyProvider: + context = None + + mgr._entries["srv"] = _ProviderEntry( + server_url="https://example.com/mcp", + oauth_config=None, + provider=_DummyProvider(), + ) + + # Fan out N concurrent callers sharing the same failed token so all + # collapse onto a single deduped handler future. + async def _caller(): + return await mgr.handle_401("srv", failed_access_token="TOK") + + tasks = [asyncio.create_task(_caller()) for _ in range(8)] + # Give the event loop one tick to schedule _do_handle, then force GC. + await asyncio.sleep(0) + gc.collect() + + results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=5.0) + assert results == [False] * 8 + assert len(mgr._inflight_tasks) == 0 + + def test_manager_builds_hermes_provider_subclass(tmp_path, monkeypatch): """get_or_build_provider returns HermesMCPOAuthProvider, not plain OAuthClientProvider.""" from tools.mcp_oauth_manager import ( diff --git a/tools/mcp_oauth_manager.py b/tools/mcp_oauth_manager.py index 1011c16bd..8fe1c66f8 100644 --- a/tools/mcp_oauth_manager.py +++ b/tools/mcp_oauth_manager.py @@ -451,6 +451,10 @@ class MCPOAuthManager: def __init__(self) -> None: self._entries: dict[str, _ProviderEntry] = {} self._entries_lock = threading.Lock() + # Holds strong references to in-flight 401 handler tasks so the + # event loop's weak-reference bookkeeping cannot GC them mid-run + # and leave `await pending` waiters hanging forever. + self._inflight_tasks: set[asyncio.Task] = set() # -- Provider construction / caching ------------------------------------- @@ -677,7 +681,9 @@ class MCPOAuthManager: finally: entry.pending_401.pop(key, None) - asyncio.create_task(_do_handle()) + task = asyncio.create_task(_do_handle()) + self._inflight_tasks.add(task) + task.add_done_callback(self._inflight_tasks.discard) try: return await pending