diff --git a/gateway/platforms/yuanbao_media.py b/gateway/platforms/yuanbao_media.py
index 87eefcdda..85abb3049 100644
--- a/gateway/platforms/yuanbao_media.py
+++ b/gateway/platforms/yuanbao_media.py
@@ -217,8 +217,40 @@ async def download_url(
         ValueError: 内容超过大小限制
         httpx.HTTPError: 网络/HTTP 错误
     """
+    # SSRF protection: yuanbao downloads model-supplied and inbound URLs
+    # server-side. Reject private/internal targets up front, and re-validate
+    # every redirect hop so a public URL can't 302 to http://169.254.169.254/.
+    from tools.url_safety import is_safe_url
+
+    if not is_safe_url(url):
+        raise ValueError(f"Blocked unsafe URL (SSRF protection): {url}")
+
+    async def _redirect_guard(response: httpx.Response) -> None:
+        """Raise ValueError when a redirect response points at an unsafe URL."""
+        if not response.is_redirect:
+            return
+        # httpx runs response hooks before it builds the follow-up request and
+        # only fills in next_request when redirects are NOT being followed, so
+        # with follow_redirects=True we must resolve the Location header here.
+        if response.next_request is not None:
+            redirect_url = str(response.next_request.url)
+        else:
+            try:
+                redirect_url = str(response.url.join(response.headers["location"]))
+            except httpx.InvalidURL:
+                # httpx rejects a malformed Location itself; nothing is followed.
+                return
+        if not is_safe_url(redirect_url):
+            raise ValueError(
+                f"Blocked redirect to private/internal address: {redirect_url}"
+            )
+
     max_bytes = max_size_mb * 1024 * 1024
-    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
+    async with httpx.AsyncClient(
+        timeout=30.0,
+        follow_redirects=True,
+        event_hooks={"response": [_redirect_guard]},
+    ) as client:
         # 先 HEAD 检查大小
         try:
             head = await client.head(url)
diff --git a/tests/gateway/test_yuanbao_media_ssrf.py b/tests/gateway/test_yuanbao_media_ssrf.py
new file mode 100644
index 000000000..329a18e7c
--- /dev/null
+++ b/tests/gateway/test_yuanbao_media_ssrf.py
@@ -0,0 +1,92 @@
+"""SSRF protection tests for yuanbao_media.download_url().
+
+download_url() fetches both model-supplied (outbound) and inbound image/file
+URLs server-side via httpx. Without an is_safe_url() pre-flight, a model
+response (or inbound message) containing http://169.254.169.254/... would make
+the gateway fetch cloud-metadata endpoints. These tests pin the guard.
+"""
+
+import pytest
+
+from gateway.platforms.yuanbao_media import download_url
+
+
+class TestDownloadUrlSSRF:
+    @pytest.mark.asyncio
+    async def test_metadata_endpoint_blocked(self):
+        with pytest.raises(ValueError, match="SSRF protection"):
+            await download_url("http://169.254.169.254/latest/meta-data/")
+
+    @pytest.mark.asyncio
+    async def test_loopback_blocked(self):
+        with pytest.raises(ValueError, match="SSRF protection"):
+            await download_url("http://127.0.0.1:8080/secret")
+
+    @pytest.mark.asyncio
+    async def test_private_range_blocked(self):
+        with pytest.raises(ValueError, match="SSRF protection"):
+            await download_url("http://192.168.1.1/admin/logo.png")
+
+    @pytest.mark.asyncio
+    async def test_non_http_scheme_blocked(self):
+        with pytest.raises(ValueError, match="SSRF protection"):
+            await download_url("file:///etc/passwd")
+
+    @pytest.mark.asyncio
+    async def test_public_url_passes_guard_then_fetches(self, monkeypatch):
+        """A public URL clears the SSRF guard and reaches the HTTP client.
+
+        We stub is_safe_url True and the httpx client so no real network call
+        happens — the assertion is that the guard does not reject a public URL.
+        """
+        import gateway.platforms.yuanbao_media as ym
+
+        fetched = {}
+
+        class _FakeResp:
+            headers = {"content-type": "image/png", "content-length": "3"}
+            is_redirect = False
+            next_request = None
+
+            def raise_for_status(self):
+                pass
+
+            async def aiter_bytes(self, _n):
+                yield b"png"
+
+        class _FakeStream:
+            async def __aenter__(self):
+                return _FakeResp()
+
+            async def __aexit__(self, *a):
+                return False
+
+        class _FakeClient:
+            def __init__(self, *a, **kw):
+                fetched["hooks"] = kw.get("event_hooks")
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *a):
+                return False
+
+            async def head(self, url):
+                return _FakeResp()
+
+            def stream(self, method, url, **kw):
+                fetched["url"] = url
+                return _FakeStream()
+
+        monkeypatch.setattr(ym, "is_safe_url", lambda u: True, raising=False)
+        # is_safe_url is imported inside the function, so patch the source too
+        from tools import url_safety
+        monkeypatch.setattr(url_safety, "is_safe_url", lambda u: True)
+        monkeypatch.setattr(ym.httpx, "AsyncClient", _FakeClient)
+
+        data, ct = await download_url("https://example.com/image.png")
+        assert data == b"png"
+        assert ct == "image/png"
+        # The guarded client must register a redirect event hook.
+        assert fetched["hooks"] is not None
+        assert "response" in fetched["hooks"]