diff --git a/agent/agent_runtime_helpers.py b/agent/agent_runtime_helpers.py index 9c7a04629..18ed3102c 100644 --- a/agent/agent_runtime_helpers.py +++ b/agent/agent_runtime_helpers.py @@ -1187,14 +1187,29 @@ def restore_primary_runtime(agent) -> bool: entry_provider = str(getattr(entry, "provider", "") or "").strip().lower() primary_provider = str(rt.get("provider") or "").strip().lower() entry_matches_primary = entry_provider == primary_provider - if primary_provider == "custom" and entry_provider.startswith("custom:"): - primary_base_url = str(rt.get("base_url") or "").strip().rstrip("/").lower() - entry_base_url = str( - getattr(entry, "runtime_base_url", None) - or getattr(entry, "base_url", None) - or "" - ).strip().rstrip("/").lower() - entry_matches_primary = bool(primary_base_url and entry_base_url == primary_base_url) + # Custom endpoints all carry the generic ``custom`` provider on + # the agent while the pool entry is keyed ``custom:`` (see + # CUSTOM_POOL_PREFIX). Resolve the primary's base_url to its + # ``custom:`` key via the canonical helper and compare + # against the entry's key — this mirrors the sibling guard in + # ``recover_with_credential_pool`` (see above) and correctly + # disambiguates multiple custom providers that share one gateway + # base_url. Fixes #56885. + from agent.credential_pool import CUSTOM_POOL_PREFIX + if ( + primary_provider == "custom" + and entry_provider.startswith(CUSTOM_POOL_PREFIX) + ): + entry_matches_primary = False + try: + from agent.credential_pool import get_custom_provider_pool_key + primary_base_url = str(rt.get("base_url") or "").strip() + primary_key = ( + get_custom_provider_pool_key(primary_base_url) or "" + ).strip().lower() + entry_matches_primary = bool(primary_key) and primary_key == entry_provider + except Exception: + entry_matches_primary = False entry_key = ( getattr(entry, "runtime_api_key", None) diff --git a/tests/run_agent/test_primary_runtime_restore.py b/tests/run_agent/test_primary_runtime_restore.py index 06af15bbc..d1ac56dca 100644 --- a/tests/run_agent/test_primary_runtime_restore.py +++ b/tests/run_agent/test_primary_runtime_restore.py @@ -241,6 +241,139 @@ class TestRestorePrimaryRuntime: assert agent.base_url == original_base_url agent._swap_credential.assert_not_called() + def test_restore_keeps_primary_base_url_when_fallback_pool_attached(self): + """Issue #56885: plain-provider primary must not inherit a fallback + provider's base_url via the restore-path pool reselect. + + Repro: primary is openai-api/gpt-5.5, a transient failure falls back to + deepseek and attaches deepseek's credential pool. On the next turn the + restore reselect must NOT swap in the deepseek entry — otherwise the + request goes out as model=gpt-5.5 to base_url=api.deepseek.com → 404. + """ + + class _DeepseekEntry: + provider = "deepseek" + id = "dsk-1" + label = "deepseek-key" + runtime_api_key = "sk-deepseek-xxx" + runtime_base_url = "https://api.deepseek.com/v1" + base_url = "https://api.deepseek.com/v1" + access_token = "sk-deepseek-xxx" + + class _DeepseekPool: + provider = "deepseek" + + def has_available(self): + return True + + def select(self): + return _DeepseekEntry() + + agent = _make_agent( + provider="openai-api", + base_url="https://api.openai.com/v1", + fallback_model={"provider": "deepseek", "model": "deepseek-v4-flash"}, + ) + primary_base_url = agent.base_url + primary_provider = agent.provider + mock_client = _mock_resolve(base_url="https://api.deepseek.com/v1") + with patch( + "agent.auxiliary_client.resolve_provider_client", + return_value=(mock_client, None), + ): + agent._try_activate_fallback() + # Fallback attached deepseek's pool; simulate it surviving into the next turn. + agent._credential_pool = _DeepseekPool() + agent._swap_credential = MagicMock() + + with patch("run_agent.OpenAI", return_value=MagicMock()): + result = agent._restore_primary_runtime() + + assert result is True + assert agent.provider == primary_provider + assert agent.base_url == primary_base_url + assert "deepseek" not in str(agent.base_url) + agent._swap_credential.assert_not_called() + + def test_restore_swaps_matching_custom_pool_entry(self): + """Custom primary + custom: entry whose base_url resolves to the + SAME custom key must swap (legitimate same-endpoint rotation).""" + + class _Entry: + provider = "custom:myllm" + id = "custom-entry" + label = "myllm" + runtime_api_key = "custom-key" + runtime_base_url = "https://my-llm.example.com/v1" + access_token = "custom-key" + + class _Pool: + provider = "custom:myllm" + + def has_available(self): + return True + + def select(self): + return _Entry() + + agent = _make_agent(provider="custom", base_url="https://my-llm.example.com/v1") + agent._fallback_activated = True + agent._credential_pool = _Pool() + agent._swap_credential = MagicMock() + + with ( + patch( + "agent.credential_pool.get_custom_provider_pool_key", + return_value="custom:myllm", + ), + patch("run_agent.OpenAI", return_value=MagicMock()), + ): + result = agent._restore_primary_runtime() + + assert result is True + agent._swap_credential.assert_called_once() + + def test_restore_skips_cross_endpoint_custom_pool_entry(self): + """Custom primary + custom: entry whose base_url resolves to a + DIFFERENT custom key must skip — two named custom providers sharing a + gateway must not cross-contaminate.""" + + class _Entry: + provider = "custom:otherllm" + id = "other-entry" + label = "otherllm" + runtime_api_key = "other-key" + runtime_base_url = "https://my-llm.example.com/v1" + access_token = "other-key" + + class _Pool: + provider = "custom:otherllm" + + def has_available(self): + return True + + def select(self): + return _Entry() + + agent = _make_agent(provider="custom", base_url="https://my-llm.example.com/v1") + agent._fallback_activated = True + original_base_url = agent.base_url + agent._credential_pool = _Pool() + agent._swap_credential = MagicMock() + + with ( + patch( + "agent.credential_pool.get_custom_provider_pool_key", + return_value="custom:myllm", # primary resolves to a DIFFERENT key + ), + patch("run_agent.OpenAI", return_value=MagicMock()), + ): + result = agent._restore_primary_runtime() + + assert result is True + assert agent.base_url == original_base_url + agent._swap_credential.assert_not_called() + def test_restore_survives_exception(self): """If client rebuild fails, the method returns False gracefully.""" agent = _make_agent()