fix(agent): route restore custom-pool match through canonical helper
Follow-up on the salvaged #56392 guard. The cherry-picked change matched custom:<name> pool entries against the primary by raw base_url string equality, which (a) cannot disambiguate two named custom providers that share one gateway base_url and (b) left a latent bypass for a bare "custom" entry. Route the match through get_custom_provider_pool_key(rt["base_url"]) and compare the result against the entry's custom:<name> key, mirroring the sibling guard in recover_with_credential_pool. Use CUSTOM_POOL_PREFIX instead of the literal "custom:" prefix. Add regression tests for the custom same-endpoint (swap) and cross-endpoint (skip) branches, plus the plain-provider fallback-pool case from #56885.
This commit is contained in:
parent
820a052575
commit
b837f07dcd
2 changed files with 156 additions and 8 deletions
|
|
@ -1187,14 +1187,29 @@ def restore_primary_runtime(agent) -> bool:
|
|||
entry_provider = str(getattr(entry, "provider", "") or "").strip().lower()
|
||||
primary_provider = str(rt.get("provider") or "").strip().lower()
|
||||
entry_matches_primary = entry_provider == primary_provider
|
||||
if primary_provider == "custom" and entry_provider.startswith("custom:"):
|
||||
primary_base_url = str(rt.get("base_url") or "").strip().rstrip("/").lower()
|
||||
entry_base_url = str(
|
||||
getattr(entry, "runtime_base_url", None)
|
||||
or getattr(entry, "base_url", None)
|
||||
or ""
|
||||
).strip().rstrip("/").lower()
|
||||
entry_matches_primary = bool(primary_base_url and entry_base_url == primary_base_url)
|
||||
# Custom endpoints all carry the generic ``custom`` provider on
|
||||
# the agent while the pool entry is keyed ``custom:<name>`` (see
|
||||
# CUSTOM_POOL_PREFIX). Resolve the primary's base_url to its
|
||||
# ``custom:<name>`` key via the canonical helper and compare
|
||||
# against the entry's key — this mirrors the sibling guard in
|
||||
# ``recover_with_credential_pool`` (see above) and correctly
|
||||
# disambiguates multiple custom providers that share one gateway
|
||||
# base_url. Fixes #56885.
|
||||
from agent.credential_pool import CUSTOM_POOL_PREFIX
|
||||
if (
|
||||
primary_provider == "custom"
|
||||
and entry_provider.startswith(CUSTOM_POOL_PREFIX)
|
||||
):
|
||||
entry_matches_primary = False
|
||||
try:
|
||||
from agent.credential_pool import get_custom_provider_pool_key
|
||||
primary_base_url = str(rt.get("base_url") or "").strip()
|
||||
primary_key = (
|
||||
get_custom_provider_pool_key(primary_base_url) or ""
|
||||
).strip().lower()
|
||||
entry_matches_primary = bool(primary_key) and primary_key == entry_provider
|
||||
except Exception:
|
||||
entry_matches_primary = False
|
||||
|
||||
entry_key = (
|
||||
getattr(entry, "runtime_api_key", None)
|
||||
|
|
|
|||
|
|
@ -241,6 +241,139 @@ class TestRestorePrimaryRuntime:
|
|||
assert agent.base_url == original_base_url
|
||||
agent._swap_credential.assert_not_called()
|
||||
|
||||
def test_restore_keeps_primary_base_url_when_fallback_pool_attached(self):
    """Regression test for issue #56885 (plain-provider primary).

    Scenario: the primary runtime is openai-api, the fallback towards
    deepseek is activated, and a deepseek credential pool is then attached
    to the agent as if it had survived into the next turn. Restoring the
    primary must leave provider and base_url untouched and must not swap
    in the deepseek pool entry -- otherwise the primary model would be
    sent to the deepseek endpoint.
    """
    deepseek_url = "https://api.deepseek.com/v1"
    deepseek_key = "sk-deepseek-xxx"

    class _FallbackEntry:
        # Pool entry belonging to the fallback provider, not the primary.
        provider = "deepseek"
        id = "dsk-1"
        label = "deepseek-key"
        runtime_api_key = deepseek_key
        runtime_base_url = deepseek_url
        base_url = deepseek_url
        access_token = deepseek_key

    class _FallbackPool:
        # Minimal pool stand-in: always available, always yields the
        # deepseek entry.
        provider = "deepseek"

        def has_available(self):
            return True

        def select(self):
            return _FallbackEntry()

    agent = _make_agent(
        provider="openai-api",
        base_url="https://api.openai.com/v1",
        fallback_model={"provider": "deepseek", "model": "deepseek-v4-flash"},
    )
    # Snapshot the primary runtime before the fallback mutates the agent.
    expected_base_url, expected_provider = agent.base_url, agent.provider

    fallback_client = _mock_resolve(base_url=deepseek_url)
    resolver_patch = patch(
        "agent.auxiliary_client.resolve_provider_client",
        return_value=(fallback_client, None),
    )
    with resolver_patch:
        agent._try_activate_fallback()

    # Simulate the fallback provider's pool still being attached on the
    # next turn, and spy on credential swaps.
    agent._credential_pool = _FallbackPool()
    agent._swap_credential = MagicMock()

    with patch("run_agent.OpenAI", return_value=MagicMock()):
        restored = agent._restore_primary_runtime()

    assert restored is True
    assert agent.provider == expected_provider
    assert agent.base_url == expected_base_url
    assert "deepseek" not in str(agent.base_url)
    agent._swap_credential.assert_not_called()
|
||||
|
||||
def test_restore_swaps_matching_custom_pool_entry(self):
    """A custom primary whose base_url resolves to the pool entry's own
    ``custom:<name>`` key must swap that entry in on restore (legitimate
    same-endpoint rotation)."""
    endpoint = "https://my-llm.example.com/v1"
    pool_key = "custom:myllm"

    class _MatchingEntry:
        provider = pool_key
        id = "custom-entry"
        label = "myllm"
        runtime_api_key = "custom-key"
        runtime_base_url = endpoint
        access_token = "custom-key"

    class _MatchingPool:
        provider = pool_key

        def has_available(self):
            return True

        def select(self):
            return _MatchingEntry()

    agent = _make_agent(provider="custom", base_url=endpoint)
    agent._fallback_activated = True
    agent._credential_pool = _MatchingPool()
    agent._swap_credential = MagicMock()

    # The helper is patched so the primary resolves to the same key the
    # pool entry carries.
    with patch(
        "agent.credential_pool.get_custom_provider_pool_key",
        return_value="custom:myllm",
    ):
        with patch("run_agent.OpenAI", return_value=MagicMock()):
            restored = agent._restore_primary_runtime()

    assert restored is True
    agent._swap_credential.assert_called_once()
|
||||
|
||||
def test_restore_skips_cross_endpoint_custom_pool_entry(self):
    """A custom primary must not take a pool entry keyed under a different
    ``custom:<name>``, even when both share the same gateway base_url.

    The entry is keyed ``custom:otherllm`` while the patched helper
    resolves the primary to ``custom:myllm``; restore must succeed without
    swapping credentials or changing base_url.
    """
    shared_gateway = "https://my-llm.example.com/v1"

    class _ForeignEntry:
        provider = "custom:otherllm"
        id = "other-entry"
        label = "otherllm"
        runtime_api_key = "other-key"
        runtime_base_url = shared_gateway
        access_token = "other-key"

    class _ForeignPool:
        provider = "custom:otherllm"

        def has_available(self):
            return True

        def select(self):
            return _ForeignEntry()

    agent = _make_agent(provider="custom", base_url=shared_gateway)
    agent._fallback_activated = True
    base_url_before = agent.base_url
    agent._credential_pool = _ForeignPool()
    agent._swap_credential = MagicMock()

    # The primary resolves to a DIFFERENT key than the one on the entry.
    with patch(
        "agent.credential_pool.get_custom_provider_pool_key",
        return_value="custom:myllm",
    ):
        with patch("run_agent.OpenAI", return_value=MagicMock()):
            restored = agent._restore_primary_runtime()

    assert restored is True
    assert agent.base_url == base_url_before
    agent._swap_credential.assert_not_called()
|
||||
|
||||
def test_restore_survives_exception(self):
|
||||
"""If client rebuild fails, the method returns False gracefully."""
|
||||
agent = _make_agent()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue