fix(agent): route restore custom-pool match through canonical helper

Follow-up on the salvaged #56392 guard. The cherry-picked change matched
custom:<name> pool entries against the primary by raw base_url string
equality, which (a) can't disambiguate two named custom providers sharing
one gateway base_url and (b) left a latent bare-"custom" entry bypass.

Route the match through get_custom_provider_pool_key(rt[base_url]) compared
against the entry's custom:<name> key, mirroring the sibling guard in
recover_with_credential_pool. Use CUSTOM_POOL_PREFIX instead of the literal.

Add regression tests for the custom same-endpoint (swap) and cross-endpoint
(skip) branches, plus the plain-provider fallback-pool case from #56885.
This commit is contained in:
kshitijk4poor 2026-07-02 13:36:04 +05:30 committed by kshitij
parent 820a052575
commit b837f07dcd
2 changed files with 156 additions and 8 deletions

View file

@ -1187,14 +1187,29 @@ def restore_primary_runtime(agent) -> bool:
entry_provider = str(getattr(entry, "provider", "") or "").strip().lower()
primary_provider = str(rt.get("provider") or "").strip().lower()
entry_matches_primary = entry_provider == primary_provider
if primary_provider == "custom" and entry_provider.startswith("custom:"):
primary_base_url = str(rt.get("base_url") or "").strip().rstrip("/").lower()
entry_base_url = str(
getattr(entry, "runtime_base_url", None)
or getattr(entry, "base_url", None)
or ""
).strip().rstrip("/").lower()
entry_matches_primary = bool(primary_base_url and entry_base_url == primary_base_url)
# Custom endpoints all carry the generic ``custom`` provider on
# the agent while the pool entry is keyed ``custom:<name>`` (see
# CUSTOM_POOL_PREFIX). Resolve the primary's base_url to its
# ``custom:<name>`` key via the canonical helper and compare
# against the entry's key — this mirrors the sibling guard in
# ``recover_with_credential_pool`` (see above) and correctly
# disambiguates multiple custom providers that share one gateway
# base_url. Fixes #56885.
from agent.credential_pool import CUSTOM_POOL_PREFIX
if (
primary_provider == "custom"
and entry_provider.startswith(CUSTOM_POOL_PREFIX)
):
entry_matches_primary = False
try:
from agent.credential_pool import get_custom_provider_pool_key
primary_base_url = str(rt.get("base_url") or "").strip()
primary_key = (
get_custom_provider_pool_key(primary_base_url) or ""
).strip().lower()
entry_matches_primary = bool(primary_key) and primary_key == entry_provider
except Exception:
entry_matches_primary = False
entry_key = (
getattr(entry, "runtime_api_key", None)

View file

@ -241,6 +241,139 @@ class TestRestorePrimaryRuntime:
assert agent.base_url == original_base_url
agent._swap_credential.assert_not_called()
def test_restore_keeps_primary_base_url_when_fallback_pool_attached(self):
"""Issue #56885: plain-provider primary must not inherit a fallback
provider's base_url via the restore-path pool reselect.
Repro: primary is openai-api/gpt-5.5, a transient failure falls back to
deepseek and attaches deepseek's credential pool. On the next turn the
restore reselect must NOT swap in the deepseek entry otherwise the
request goes out as model=gpt-5.5 to base_url=api.deepseek.com 404.
"""
class _DeepseekEntry:
provider = "deepseek"
id = "dsk-1"
label = "deepseek-key"
runtime_api_key = "sk-deepseek-xxx"
runtime_base_url = "https://api.deepseek.com/v1"
base_url = "https://api.deepseek.com/v1"
access_token = "sk-deepseek-xxx"
class _DeepseekPool:
provider = "deepseek"
def has_available(self):
return True
def select(self):
return _DeepseekEntry()
agent = _make_agent(
provider="openai-api",
base_url="https://api.openai.com/v1",
fallback_model={"provider": "deepseek", "model": "deepseek-v4-flash"},
)
primary_base_url = agent.base_url
primary_provider = agent.provider
mock_client = _mock_resolve(base_url="https://api.deepseek.com/v1")
with patch(
"agent.auxiliary_client.resolve_provider_client",
return_value=(mock_client, None),
):
agent._try_activate_fallback()
# Fallback attached deepseek's pool; simulate it surviving into the next turn.
agent._credential_pool = _DeepseekPool()
agent._swap_credential = MagicMock()
with patch("run_agent.OpenAI", return_value=MagicMock()):
result = agent._restore_primary_runtime()
assert result is True
assert agent.provider == primary_provider
assert agent.base_url == primary_base_url
assert "deepseek" not in str(agent.base_url)
agent._swap_credential.assert_not_called()
def test_restore_swaps_matching_custom_pool_entry(self):
"""Custom primary + custom:<name> entry whose base_url resolves to the
SAME custom key must swap (legitimate same-endpoint rotation)."""
class _Entry:
provider = "custom:myllm"
id = "custom-entry"
label = "myllm"
runtime_api_key = "custom-key"
runtime_base_url = "https://my-llm.example.com/v1"
access_token = "custom-key"
class _Pool:
provider = "custom:myllm"
def has_available(self):
return True
def select(self):
return _Entry()
agent = _make_agent(provider="custom", base_url="https://my-llm.example.com/v1")
agent._fallback_activated = True
agent._credential_pool = _Pool()
agent._swap_credential = MagicMock()
with (
patch(
"agent.credential_pool.get_custom_provider_pool_key",
return_value="custom:myllm",
),
patch("run_agent.OpenAI", return_value=MagicMock()),
):
result = agent._restore_primary_runtime()
assert result is True
agent._swap_credential.assert_called_once()
def test_restore_skips_cross_endpoint_custom_pool_entry(self):
"""Custom primary + custom:<name> entry whose base_url resolves to a
DIFFERENT custom key must skip two named custom providers sharing a
gateway must not cross-contaminate."""
class _Entry:
provider = "custom:otherllm"
id = "other-entry"
label = "otherllm"
runtime_api_key = "other-key"
runtime_base_url = "https://my-llm.example.com/v1"
access_token = "other-key"
class _Pool:
provider = "custom:otherllm"
def has_available(self):
return True
def select(self):
return _Entry()
agent = _make_agent(provider="custom", base_url="https://my-llm.example.com/v1")
agent._fallback_activated = True
original_base_url = agent.base_url
agent._credential_pool = _Pool()
agent._swap_credential = MagicMock()
with (
patch(
"agent.credential_pool.get_custom_provider_pool_key",
return_value="custom:myllm", # primary resolves to a DIFFERENT key
),
patch("run_agent.OpenAI", return_value=MagicMock()),
):
result = agent._restore_primary_runtime()
assert result is True
assert agent.base_url == original_base_url
agent._swap_credential.assert_not_called()
def test_restore_survives_exception(self):
"""If client rebuild fails, the method returns False gracefully."""
agent = _make_agent()