fix(gateway): preserve queued native image attachments
This commit is contained in:
parent
e880396488
commit
bb24ac6f20
2 changed files with 161 additions and 1 deletions
|
|
@ -19063,6 +19063,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
next_message = pending
|
||||
next_message_id = None
|
||||
next_channel_prompt = None
|
||||
next_session_key = session_key
|
||||
if pending_event is not None:
|
||||
next_source = getattr(pending_event, "source", None) or source
|
||||
if self._is_goal_continuation_event(pending_event) and not self._goal_still_active_for_session(session_id):
|
||||
|
|
@ -19081,6 +19082,14 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
return result
|
||||
next_message_id = self._reply_anchor_for_event(pending_event)
|
||||
next_channel_prompt = getattr(pending_event, "channel_prompt", None)
|
||||
try:
|
||||
next_session_key = self._session_key_for_source(next_source)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Queued follow-up session-key resolution failed; reusing %s",
|
||||
session_key or "?",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Restart typing indicator so the user sees activity while
|
||||
# the follow-up turn runs. The outer _process_message_background
|
||||
|
|
@ -19117,7 +19126,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
|
|||
history=updated_history,
|
||||
source=next_source,
|
||||
session_id=session_id,
|
||||
session_key=session_key,
|
||||
session_key=next_session_key,
|
||||
run_generation=run_generation,
|
||||
_interrupt_depth=_interrupt_depth + 1,
|
||||
event_message_id=next_message_id,
|
||||
|
|
|
|||
151
tests/gateway/test_queued_native_image_session_key.py
Normal file
151
tests/gateway/test_queued_native_image_session_key.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
import base64
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# Raw bytes of a tiny PNG file, decoded once at import time; the test below
# writes them to disk to serve as the queued image attachment.
_ONE_BY_ONE_PNG = base64.b64decode(
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO6L2ioAAAAASUVORK5CYII="
)
|
||||
|
||||
|
||||
class CaptureAdapter(BasePlatformAdapter):
    """Platform adapter stub that records outgoing traffic in memory.

    Every ``send`` call is appended to ``self.sent`` and every
    ``send_typing`` call to ``self.typing``; the remaining methods are
    fixed-value stand-ins so the adapter can be used without a network.
    """

    def __init__(self):
        stub_config = PlatformConfig(enabled=True, token="***")
        super().__init__(stub_config, Platform.TELEGRAM)
        # One dict per send() call, in call order.
        self.sent = []
        # One dict per send_typing() call, in call order.
        self.typing = []

    async def connect(self) -> bool:
        """Report a successful connection without doing any work."""
        return True

    async def disconnect(self) -> None:
        """Do nothing."""
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult:
        """Record the outgoing message and return a successful ``SendResult``."""
        record = dict(
            chat_id=chat_id,
            content=content,
            reply_to=reply_to,
            metadata=metadata,
        )
        self.sent.append(record)
        return SendResult(success=True, message_id="sent-1")

    async def send_typing(self, chat_id, metadata=None) -> None:
        """Record that a typing indicator was requested for ``chat_id``."""
        record = dict(chat_id=chat_id, metadata=metadata)
        self.typing.append(record)

    async def stop_typing(self, chat_id) -> None:
        """Do nothing."""
        return None

    async def get_chat_info(self, chat_id: str):
        """Return a minimal chat-info mapping that only echoes the id."""
        return dict(id=chat_id)
|
||||
|
||||
|
||||
class CaptureQueuedNativeImageAgent:
    """Agent stand-in that remembers every message it is asked to run.

    ``calls`` is a class-level list shared by all instances, so a test can
    inspect the messages after the code under test has built and discarded
    its own agent objects. Tests are expected to reset it before use.
    """

    calls = []

    def __init__(self, **kwargs):
        # Any keyword arguments are accepted; only the progress callback is kept.
        self.tool_progress_callback = kwargs.get("tool_progress_callback")
        self.tools = []

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Record ``message`` and return a canned result.

        The ``final_response`` is ``done-<n>``, where ``n`` is the number of
        messages recorded so far, including this one. ``conversation_history``
        and ``task_id`` are accepted but ignored.
        """
        recorded = type(self).calls
        recorded.append(message)
        turn_number = len(recorded)
        return dict(
            final_response=f"done-{turn_number}",
            messages=[],
            api_calls=1,
        )
|
||||
|
||||
|
||||
def _make_runner(adapter):
    """Build a ``GatewayRunner`` for ``adapter`` without running ``__init__``.

    The instance is allocated with ``object.__new__`` and then given the
    attributes listed below, in that order. ``_decide_image_input_mode`` is
    replaced by a stub that always answers ``"native"``.
    """
    gateway_run = importlib.import_module("gateway.run")
    runner = object.__new__(gateway_run.GatewayRunner)

    stub_config = SimpleNamespace(
        thread_sessions_per_user=False,
        group_sessions_per_user=False,
        stt_enabled=False,
    )
    # Insertion order matches the order in which the attributes are assigned.
    attributes = {
        "adapters": {adapter.platform: adapter},
        "_voice_mode": {},
        "_prefill_messages": [],
        "_ephemeral_system_prompt": "",
        "_reasoning_config": None,
        "_provider_routing": {},
        "_fallback_model": None,
        "_session_db": None,
        "_running_agents": {},
        "_session_run_generation": {},
        "hooks": SimpleNamespace(loaded_hooks=False),
        "config": stub_config,
        "_model": "openai/gpt-4.1-mini",
        "_base_url": None,
        "_decide_image_input_mode": lambda: "native",
    }
    for attr_name, attr_value in attributes.items():
        setattr(runner, attr_name, attr_value)
    return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_queued_followup_uses_pending_event_session_key_for_native_images(monkeypatch, tmp_path):
    """A queued photo follow-up must reach the agent with its image attached.

    A photo event whose source carries a ``thread_id`` is queued on the
    adapter under the session key of the first turn. ``_run_agent`` is then
    called once; the assertions check that the agent ran twice and that the
    second message is a list of parts with a text part and an ``image_url``
    part.
    """
    CaptureQueuedNativeImageAgent.calls = []

    # Replace ``dotenv`` and ``run_agent`` with stubs before ``gateway.run``
    # is imported, so the capture agent stands in for ``AIAgent``.
    dotenv_stub = types.ModuleType("dotenv")
    dotenv_stub.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", dotenv_stub)

    run_agent_stub = types.ModuleType("run_agent")
    run_agent_stub.AIAgent = CaptureQueuedNativeImageAgent
    monkeypatch.setitem(sys.modules, "run_agent", run_agent_stub)

    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})

    adapter = CaptureAdapter()
    runner = _make_runner(adapter)

    # A real PNG on disk for the queued event to point at.
    image_path = tmp_path / "queued-image.png"
    image_path.write_bytes(_ONE_BY_ONE_PNG)

    # Same chat for both turns; only the queued event names a thread.
    first_turn_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1001",
        chat_type="group",
    )
    queued_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1001",
        chat_type="group",
        thread_id="17585",
    )

    first_turn_session_key = "agent:main:telegram:group:-1001"
    queued_event = MessageEvent(
        text="describe this",
        message_type=MessageType.PHOTO,
        source=queued_source,
        media_urls=[str(image_path)],
        media_types=["image/png"],
        message_id="queued-1",
    )
    adapter._pending_messages[first_turn_session_key] = queued_event

    result = await runner._run_agent(
        message="hello",
        context_prompt="",
        history=[],
        source=first_turn_source,
        session_id="sess-native-image-followup",
        session_key=first_turn_session_key,
    )

    # Two agent runs: the original message, then the queued follow-up.
    assert result["final_response"] == "done-2"
    recorded_calls = CaptureQueuedNativeImageAgent.calls
    assert len(recorded_calls) == 2

    followup_parts = recorded_calls[1]
    assert isinstance(followup_parts, list)
    leading_part = followup_parts[0]
    assert leading_part["type"] == "text"
    assert leading_part["text"].startswith("describe this")
    assert any(part.get("type") == "image_url" for part in followup_parts)
|
||||
Loading…
Add table
Add a link
Reference in a new issue