fix(agent): honor custom CA certs for custom_providers HTTPS endpoints
Wire ssl_ca_cert and ssl_verify through custom_providers config and env vars into the keepalive httpx client, fixing APIConnectionError against mkcert/self-signed Ollama proxies behind HTTPS.
This commit is contained in:
parent
7e957cbd0b
commit
3a2ba959ce
6 changed files with 199 additions and 6 deletions
|
|
@ -974,6 +974,21 @@ def init_agent(
|
|||
# this mutation is reflected in the client built just below.
|
||||
agent._apply_user_default_headers()
|
||||
|
||||
try:
|
||||
from hermes_cli.config import (
|
||||
apply_custom_provider_tls_to_client_kwargs,
|
||||
get_compatible_custom_providers,
|
||||
load_config,
|
||||
)
|
||||
|
||||
apply_custom_provider_tls_to_client_kwargs(
|
||||
client_kwargs,
|
||||
str(client_kwargs.get("base_url") or agent.base_url or ""),
|
||||
get_compatible_custom_providers(load_config()),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
agent.api_key = client_kwargs.get("api_key", "")
|
||||
agent.base_url = client_kwargs.get("base_url", agent.base_url)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1513,6 +1513,7 @@ def anthropic_prompt_cache_policy(
|
|||
|
||||
def create_openai_client(agent, client_kwargs: dict, *, reason: str, shared: bool) -> Any:
|
||||
from agent.auxiliary_client import _validate_base_url, _validate_proxy_env_urls
|
||||
from agent.ssl_verify import resolve_httpx_verify
|
||||
# Treat client_kwargs as read-only. Callers pass agent._client_kwargs (or shallow
|
||||
# copies of it) in; any in-place mutation leaks back into the stored dict and is
|
||||
# reused on subsequent requests. #10933 hit this by injecting an httpx.Client
|
||||
|
|
@ -1522,6 +1523,9 @@ def create_openai_client(agent, client_kwargs: dict, *, reason: str, shared: boo
|
|||
# copy locks the contract so future transport/keepalive work can't reintroduce
|
||||
# the same class of bug.
|
||||
client_kwargs = dict(client_kwargs)
|
||||
ssl_ca_cert = client_kwargs.pop("ssl_ca_cert", None)
|
||||
ssl_verify_cfg = client_kwargs.pop("ssl_verify", None)
|
||||
httpx_verify = resolve_httpx_verify(ca_bundle=ssl_ca_cert, ssl_verify=ssl_verify_cfg)
|
||||
_validate_proxy_env_urls()
|
||||
_validate_base_url(client_kwargs.get("base_url"))
|
||||
if agent.provider == "copilot-acp" or str(client_kwargs.get("base_url", "")).startswith("acp://copilot"):
|
||||
|
|
@ -1545,7 +1549,9 @@ def create_openai_client(agent, client_kwargs: dict, *, reason: str, shared: boo
|
|||
if k in {"api_key", "base_url", "default_headers", "timeout", "http_client"}
|
||||
}
|
||||
if "http_client" not in safe_kwargs:
|
||||
keepalive_http = agent._build_keepalive_http_client(base_url)
|
||||
keepalive_http = agent._build_keepalive_http_client(
|
||||
base_url, verify=httpx_verify,
|
||||
)
|
||||
if keepalive_http is not None:
|
||||
safe_kwargs["http_client"] = keepalive_http
|
||||
client = GeminiNativeClient(**safe_kwargs)
|
||||
|
|
@ -1574,7 +1580,9 @@ def create_openai_client(agent, client_kwargs: dict, *, reason: str, shared: boo
|
|||
# Tests in ``tests/run_agent/test_create_openai_client_reuse.py`` and
|
||||
# ``tests/run_agent/test_sequential_chats_live.py`` pin this invariant.
|
||||
if "http_client" not in client_kwargs:
|
||||
keepalive_http = agent._build_keepalive_http_client(client_kwargs.get("base_url", ""))
|
||||
keepalive_http = agent._build_keepalive_http_client(
|
||||
client_kwargs.get("base_url", ""), verify=httpx_verify,
|
||||
)
|
||||
if keepalive_http is not None:
|
||||
client_kwargs["http_client"] = keepalive_http
|
||||
# Delegate all rate-limit / 5xx retry to hermes's outer conversation loop,
|
||||
|
|
@ -1778,6 +1786,16 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo
|
|||
"api_key": effective_key,
|
||||
"base_url": effective_base,
|
||||
}
|
||||
try:
|
||||
from hermes_cli.config import apply_custom_provider_tls_to_client_kwargs
|
||||
|
||||
apply_custom_provider_tls_to_client_kwargs(
|
||||
agent._client_kwargs,
|
||||
str(effective_base or ""),
|
||||
getattr(agent, "_custom_providers", None),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
_sm_timeout = get_provider_request_timeout(agent.provider, agent.model)
|
||||
if _sm_timeout is not None:
|
||||
agent._client_kwargs["timeout"] = _sm_timeout
|
||||
|
|
|
|||
|
|
@ -4478,7 +4478,7 @@ def _normalize_custom_provider_entry(
|
|||
"api_mode", "transport", "model", "default_model", "models",
|
||||
"context_length", "rate_limit_delay",
|
||||
"request_timeout_seconds", "stale_timeout_seconds",
|
||||
"discover_models", "extra_body",
|
||||
"discover_models", "extra_body", "ssl_ca_cert", "ssl_verify",
|
||||
}
|
||||
for camel, snake in _CAMEL_ALIASES.items():
|
||||
if camel in entry and snake not in entry:
|
||||
|
|
@ -4585,6 +4585,16 @@ def _normalize_custom_provider_entry(
|
|||
if isinstance(extra_body, dict):
|
||||
normalized["extra_body"] = dict(extra_body)
|
||||
|
||||
ssl_ca_cert = entry.get("ssl_ca_cert")
|
||||
if isinstance(ssl_ca_cert, str) and ssl_ca_cert.strip():
|
||||
normalized["ssl_ca_cert"] = ssl_ca_cert.strip()
|
||||
|
||||
ssl_verify = entry.get("ssl_verify")
|
||||
if isinstance(ssl_verify, bool):
|
||||
normalized["ssl_verify"] = ssl_verify
|
||||
elif isinstance(ssl_verify, str) and ssl_verify.strip():
|
||||
normalized["ssl_verify"] = ssl_verify.strip()
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
|
|
@ -4612,6 +4622,8 @@ def _custom_provider_entry_to_provider_config(
|
|||
"rate_limit_delay",
|
||||
"discover_models",
|
||||
"extra_body",
|
||||
"ssl_ca_cert",
|
||||
"ssl_verify",
|
||||
):
|
||||
if field in normalized:
|
||||
provider_entry[field] = normalized[field]
|
||||
|
|
@ -4688,6 +4700,66 @@ def get_compatible_custom_providers(
|
|||
return compatible
|
||||
|
||||
|
||||
def _coerce_ssl_verify(value: Any) -> Optional[bool]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
lowered = value.strip().lower()
|
||||
if lowered in {"false", "0", "no", "off"}:
|
||||
return False
|
||||
if lowered in {"true", "1", "yes", "on"}:
|
||||
return True
|
||||
return None
|
||||
|
||||
|
||||
def get_custom_provider_tls_settings(
|
||||
base_url: str,
|
||||
custom_providers: Optional[List[Dict[str, Any]]] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Return TLS settings from a matching ``custom_providers`` / ``providers`` entry."""
|
||||
if custom_providers is None:
|
||||
try:
|
||||
custom_providers = get_compatible_custom_providers(config)
|
||||
except Exception:
|
||||
custom_providers = []
|
||||
if not base_url or not isinstance(custom_providers, list):
|
||||
return {}
|
||||
|
||||
target_url = (base_url or "").rstrip("/")
|
||||
for entry in custom_providers:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
entry_url = (entry.get("base_url") or "").rstrip("/")
|
||||
if not entry_url or entry_url != target_url:
|
||||
continue
|
||||
out: Dict[str, Any] = {}
|
||||
ca = entry.get("ssl_ca_cert")
|
||||
if isinstance(ca, str) and ca.strip():
|
||||
out["ssl_ca_cert"] = ca.strip()
|
||||
verify = _coerce_ssl_verify(entry.get("ssl_verify"))
|
||||
if verify is not None:
|
||||
out["ssl_verify"] = verify
|
||||
return out
|
||||
return {}
|
||||
|
||||
|
||||
def apply_custom_provider_tls_to_client_kwargs(
|
||||
client_kwargs: Dict[str, Any],
|
||||
base_url: str,
|
||||
custom_providers: Optional[List[Dict[str, Any]]] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Attach per-provider TLS knobs to OpenAI client kwargs when matched."""
|
||||
tls = get_custom_provider_tls_settings(base_url, custom_providers, config)
|
||||
if tls.get("ssl_ca_cert"):
|
||||
client_kwargs["ssl_ca_cert"] = tls["ssl_ca_cert"]
|
||||
if "ssl_verify" in tls:
|
||||
client_kwargs["ssl_verify"] = tls["ssl_verify"]
|
||||
|
||||
|
||||
def get_custom_provider_context_length(
|
||||
model: str,
|
||||
base_url: str,
|
||||
|
|
@ -4813,6 +4885,7 @@ _KNOWN_ROOT_KEYS = {
|
|||
_VALID_CUSTOM_PROVIDER_FIELDS = {
|
||||
"name", "base_url", "api_key", "api_mode", "model", "models",
|
||||
"context_length", "rate_limit_delay", "extra_body",
|
||||
"ssl_ca_cert", "ssl_verify",
|
||||
# key_env is read at runtime by runtime_provider.py and auxiliary_client.py
|
||||
# — include it here so the set accurately describes the supported schema.
|
||||
"key_env",
|
||||
|
|
|
|||
|
|
@ -3884,13 +3884,13 @@ class AIAgent:
|
|||
return False
|
||||
|
||||
@staticmethod
|
||||
def _build_keepalive_http_client(base_url: str = "") -> Any:
|
||||
def _build_keepalive_http_client(base_url: str = "", *, verify: Any = True) -> Any:
|
||||
try:
|
||||
import httpx as _httpx
|
||||
import socket as _socket
|
||||
|
||||
if "api.githubcopilot.com" in str(base_url or "").lower():
|
||||
return _httpx.Client()
|
||||
return _httpx.Client(verify=verify)
|
||||
|
||||
_sock_opts = [(_socket.SOL_SOCKET, _socket.SO_KEEPALIVE, 1)]
|
||||
if hasattr(_socket, "TCP_KEEPIDLE"):
|
||||
|
|
@ -3905,8 +3905,9 @@ class AIAgent:
|
|||
# loopback / local endpoints such as a locally hosted sub2api.
|
||||
_proxy = _get_proxy_for_base_url(base_url)
|
||||
return _httpx.Client(
|
||||
transport=_httpx.HTTPTransport(socket_options=_sock_opts),
|
||||
transport=_httpx.HTTPTransport(socket_options=_sock_opts, verify=verify),
|
||||
proxy=_proxy,
|
||||
verify=verify,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
|
|
|||
40
tests/hermes_cli/test_custom_provider_tls.py
Normal file
40
tests/hermes_cli/test_custom_provider_tls.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Tests for per-provider TLS settings in custom_providers config."""
|
||||
|
||||
from hermes_cli.config import (
|
||||
apply_custom_provider_tls_to_client_kwargs,
|
||||
get_custom_provider_tls_settings,
|
||||
)
|
||||
|
||||
|
||||
def test_get_custom_provider_tls_settings_matches_base_url():
|
||||
providers = [
|
||||
{
|
||||
"name": "Ollama",
|
||||
"base_url": "https://ollama.example.com/v1",
|
||||
"ssl_ca_cert": "/etc/ssl/mkcert-root.pem",
|
||||
}
|
||||
]
|
||||
tls = get_custom_provider_tls_settings(
|
||||
"https://ollama.example.com/v1/",
|
||||
custom_providers=providers,
|
||||
)
|
||||
assert tls == {"ssl_ca_cert": "/etc/ssl/mkcert-root.pem"}
|
||||
|
||||
|
||||
def test_apply_custom_provider_tls_to_client_kwargs():
|
||||
client_kwargs = {"api_key": "x", "base_url": "https://ollama.example.com/v1"}
|
||||
providers = [
|
||||
{
|
||||
"name": "Ollama",
|
||||
"base_url": "https://ollama.example.com/v1",
|
||||
"ssl_ca_cert": "/etc/ssl/mkcert-root.pem",
|
||||
"ssl_verify": True,
|
||||
}
|
||||
]
|
||||
apply_custom_provider_tls_to_client_kwargs(
|
||||
client_kwargs,
|
||||
"https://ollama.example.com/v1",
|
||||
custom_providers=providers,
|
||||
)
|
||||
assert client_kwargs["ssl_ca_cert"] == "/etc/ssl/mkcert-root.pem"
|
||||
assert client_kwargs["ssl_verify"] is True
|
||||
46
tests/run_agent/test_create_openai_client_ssl_verify.py
Normal file
46
tests/run_agent/test_create_openai_client_ssl_verify.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""Regression: keepalive httpx client must honor custom CA bundles for HTTPS providers."""
|
||||
|
||||
import ssl
|
||||
|
||||
import certifi
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from agent.ssl_verify import resolve_httpx_verify
|
||||
from run_agent import AIAgent
|
||||
|
||||
_CA_ENV_VARS = ("HERMES_CA_BUNDLE", "SSL_CERT_FILE", "REQUESTS_CA_BUNDLE", "HTTPS_PROXY")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clean_tls_env(monkeypatch):
|
||||
for var in _CA_ENV_VARS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
def test_build_keepalive_http_client_uses_hermes_ca_bundle(clean_tls_env, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_CA_BUNDLE", certifi.where())
|
||||
verify = resolve_httpx_verify()
|
||||
client = AIAgent._build_keepalive_http_client(
|
||||
"https://ollama.example.com/v1", verify=verify,
|
||||
)
|
||||
assert isinstance(client, httpx.Client)
|
||||
assert isinstance(client._transport._pool._ssl_context, ssl.SSLContext)
|
||||
|
||||
|
||||
def test_build_keepalive_http_client_honors_per_provider_ssl_ca_cert(clean_tls_env):
|
||||
verify = resolve_httpx_verify(ca_bundle=certifi.where())
|
||||
client = AIAgent._build_keepalive_http_client(
|
||||
"https://ollama.example.com/v1", verify=verify,
|
||||
)
|
||||
assert isinstance(client, httpx.Client)
|
||||
assert isinstance(client._transport._pool._ssl_context, ssl.SSLContext)
|
||||
|
||||
|
||||
def test_build_keepalive_http_client_ssl_verify_false(clean_tls_env):
|
||||
verify = resolve_httpx_verify(ssl_verify=False)
|
||||
client = AIAgent._build_keepalive_http_client(
|
||||
"https://ollama.example.com/v1", verify=verify,
|
||||
)
|
||||
assert isinstance(client, httpx.Client)
|
||||
assert client._transport._pool._ssl_context.check_hostname is False
|
||||
Loading…
Add table
Add a link
Reference in a new issue