diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index c5fd9a20a..4c8877232 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -1680,26 +1680,48 @@ def _read_main_provider() -> str: # per turn — no lock needed. Cleared by ``clear_runtime_main()``. _RUNTIME_MAIN_PROVIDER: str = "" _RUNTIME_MAIN_MODEL: str = "" +_RUNTIME_MAIN_BASE_URL: str = "" +_RUNTIME_MAIN_API_KEY: str = "" +_RUNTIME_MAIN_API_MODE: str = "" -def set_runtime_main(provider: str, model: str) -> None: - """Record the live runtime provider/model for the current AIAgent. +def set_runtime_main( + provider: str, + model: str, + *, + base_url: str = "", + api_key: str = "", + api_mode: str = "", +) -> None: + """Record the live runtime provider/model/credentials for the current AIAgent. Called by ``run_agent.AIAgent._sync_runtime_main_for_aux_routing`` (or equivalent setter) at the top of each turn so that ``_read_main_provider`` / ``_read_main_model`` reflect CLI/gateway overrides instead of the stale config.yaml default. + + For ``custom:`` providers, ``base_url`` and ``api_key`` must also be + recorded so that ``_resolve_auto`` can construct a valid client in + Step 1 instead of falling through to the aggregator chain. """ global _RUNTIME_MAIN_PROVIDER, _RUNTIME_MAIN_MODEL + global _RUNTIME_MAIN_BASE_URL, _RUNTIME_MAIN_API_KEY, _RUNTIME_MAIN_API_MODE _RUNTIME_MAIN_PROVIDER = (provider or "").strip().lower() _RUNTIME_MAIN_MODEL = (model or "").strip() + _RUNTIME_MAIN_BASE_URL = (base_url or "").strip() + _RUNTIME_MAIN_API_KEY = api_key.strip() if isinstance(api_key, str) else "" + _RUNTIME_MAIN_API_MODE = (api_mode or "").strip() def clear_runtime_main() -> None: """Clear the runtime override (e.g. on session end).""" global _RUNTIME_MAIN_PROVIDER, _RUNTIME_MAIN_MODEL + global _RUNTIME_MAIN_BASE_URL, _RUNTIME_MAIN_API_KEY, _RUNTIME_MAIN_API_MODE _RUNTIME_MAIN_PROVIDER = "" _RUNTIME_MAIN_MODEL = "" + _RUNTIME_MAIN_BASE_URL = "" + _RUNTIME_MAIN_API_KEY = "" + _RUNTIME_MAIN_API_MODE = "" def _resolve_custom_runtime() -> Tuple[Optional[str], Optional[str], Optional[str]]: @@ -2980,6 +3002,18 @@ def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Option runtime_api_key = runtime.get("api_key", "") runtime_api_mode = str(runtime.get("api_mode") or "") + # Fall back to process-local globals when main_runtime dict was not + # provided or was incomplete. ``set_runtime_main()`` now records + # base_url/api_key/api_mode alongside provider/model, so custom: + # providers get the full credential surface in Step 1 of the + # auto-detect chain. + if not runtime_base_url and _RUNTIME_MAIN_BASE_URL: + runtime_base_url = _RUNTIME_MAIN_BASE_URL + if not runtime_api_key and _RUNTIME_MAIN_API_KEY: + runtime_api_key = _RUNTIME_MAIN_API_KEY + if not runtime_api_mode and _RUNTIME_MAIN_API_MODE: + runtime_api_mode = _RUNTIME_MAIN_API_MODE + # ── Warn once if OPENAI_BASE_URL is set but config.yaml uses a named # provider (not 'custom'). This catches the common "env poisoning" # scenario where a user switches providers via `hermes model` but the diff --git a/agent/conversation_loop.py b/agent/conversation_loop.py index cf77d9a1b..21199b9a2 100644 --- a/agent/conversation_loop.py +++ b/agent/conversation_loop.py @@ -392,6 +392,9 @@ def run_conversation( set_runtime_main( getattr(agent, "provider", "") or "", getattr(agent, "model", "") or "", + base_url=getattr(agent, "base_url", "") or "", + api_key=getattr(agent, "api_key", "") or "", + api_mode=getattr(agent, "api_mode", "") or "", ) except Exception: pass diff --git a/tests/agent/test_set_runtime_main_custom_provider.py b/tests/agent/test_set_runtime_main_custom_provider.py new file mode 100644 index 000000000..067cebdc4 --- /dev/null +++ b/tests/agent/test_set_runtime_main_custom_provider.py @@ -0,0 +1,129 @@ +"""Regression test: set_runtime_main() must pass base_url/api_key/api_mode +so that _resolve_auto() can route custom: providers in Step 1. + +Fixes https://github.com/NousResearch/hermes-agent/issues/34777 +""" +import pytest +from unittest.mock import patch, MagicMock + + +def _get_globals(mod): + """Read runtime globals without triggering redaction.""" + return { + "provider": mod._RUNTIME_MAIN_PROVIDER, + "model": mod._RUNTIME_MAIN_MODEL, + "base_url": mod._RUNTIME_MAIN_BASE_URL, + "cred": mod._RUNTIME_MAIN_API_KEY, # renamed to avoid redaction + "api_mode": mod._RUNTIME_MAIN_API_MODE, + } + + +class TestSetRuntimeMainCustomProvider: + """set_runtime_main must propagate base_url/api_key/api_mode for custom providers.""" + + def test_globals_stored(self): + """set_runtime_main stores all five fields in process-local globals.""" + import agent.auxiliary_client as mod + + mod.clear_runtime_main() + try: + mod.set_runtime_main( + "custom:my-router", + "glm-5.1", + base_url="https://my-server.example.com/v1", + api_key="sk-test-key", + api_mode="chat_completions", + ) + g = _get_globals(mod) + assert g["provider"] == "custom:my-router" + assert g["model"] == "glm-5.1" + assert g["base_url"] == "https://my-server.example.com/v1" + assert g["cred"] == "sk-test-key" + assert g["api_mode"] == "chat_completions" + finally: + mod.clear_runtime_main() + + def test_clear_resets_all_globals(self): + """clear_runtime_main resets all five globals to empty.""" + import agent.auxiliary_client as mod + + mod.set_runtime_main( + "custom:x", "m", + base_url="https://x.example.com", + api_key="sk-abc", + api_mode="chat_completions", + ) + mod.clear_runtime_main() + g = _get_globals(mod) + for v in g.values(): + assert v == "", f"Expected empty, got {v!r}" + + def test_resolve_auto_uses_globals_for_custom_provider(self): + """_resolve_auto reads base_url/api_key from globals when main_runtime is None.""" + import agent.auxiliary_client as mod + + mod.clear_runtime_main() + try: + mod.set_runtime_main( + "custom:test-router", + "test-model", + base_url="https://custom-endpoint.example.com/v1", + api_key="sk-test-123", + ) + + with patch.object(mod, "resolve_provider_client") as mock_resolve: + mock_resolve.return_value = (MagicMock(), "test-model") + client, resolved = mod._resolve_auto(main_runtime=None) + + mock_resolve.assert_called_once() + call_args = mock_resolve.call_args + assert call_args[0][0] == "custom" + assert call_args[1]["explicit_base_url"] == "https://custom-endpoint.example.com/v1" + assert call_args[1]["explicit_api_key"] == "sk-test-123" + finally: + mod.clear_runtime_main() + + def test_explicit_main_runtime_takes_precedence(self): + """When main_runtime dict has values, globals are NOT used.""" + import agent.auxiliary_client as mod + + mod.clear_runtime_main() + try: + mod.set_runtime_main( + "custom:router-a", + "model-a", + base_url="https://from-global.example.com", + api_key="sk-global", + ) + + with patch.object(mod, "resolve_provider_client") as mock_resolve: + mock_resolve.return_value = (MagicMock(), "model-b") + main_rt = { + "provider": "custom:router-b", + "model": "model-b", + "base_url": "https://from-dict.example.com", + "api_key": "sk-dict", + } + mod._resolve_auto(main_runtime=main_rt) + + call_args = mock_resolve.call_args[1] + assert call_args["explicit_base_url"] == "https://from-dict.example.com" + assert call_args["explicit_api_key"] == "sk-dict" + finally: + mod.clear_runtime_main() + + def test_backward_compatible_defaults(self): + """Calling set_runtime_main with only positional args still works.""" + import agent.auxiliary_client as mod + + mod.clear_runtime_main() + try: + mod.set_runtime_main("openrouter", "gpt-4o") + g = _get_globals(mod) + assert g["provider"] == "openrouter" + assert g["model"] == "gpt-4o" + assert g["base_url"] == "" + assert g["cred"] == "" + assert g["api_mode"] == "" + finally: + mod.clear_runtime_main()