diff --git a/gateway/run.py b/gateway/run.py index 17ccd4c1c..07a1c93d7 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -19063,6 +19063,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew next_message = pending next_message_id = None next_channel_prompt = None + next_session_key = session_key if pending_event is not None: next_source = getattr(pending_event, "source", None) or source if self._is_goal_continuation_event(pending_event) and not self._goal_still_active_for_session(session_id): @@ -19081,6 +19082,14 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew return result next_message_id = self._reply_anchor_for_event(pending_event) next_channel_prompt = getattr(pending_event, "channel_prompt", None) + try: + next_session_key = self._session_key_for_source(next_source) + except Exception: + logger.debug( + "Queued follow-up session-key resolution failed; reusing %s", + session_key or "?", + exc_info=True, + ) # Restart typing indicator so the user sees activity while # the follow-up turn runs. The outer _process_message_background @@ -19117,7 +19126,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew history=updated_history, source=next_source, session_id=session_id, - session_key=session_key, + session_key=next_session_key, run_generation=run_generation, _interrupt_depth=_interrupt_depth + 1, event_message_id=next_message_id, diff --git a/tests/gateway/test_queued_native_image_session_key.py b/tests/gateway/test_queued_native_image_session_key.py new file mode 100644 index 000000000..86ba1620a --- /dev/null +++ b/tests/gateway/test_queued_native_image_session_key.py @@ -0,0 +1,151 @@ +import base64 +import importlib +import sys +import types +from types import SimpleNamespace + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult +from gateway.session import SessionSource + + +_ONE_BY_ONE_PNG = base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO6L2ioAAAAASUVORK5CYII=" +) + + +class CaptureAdapter(BasePlatformAdapter): + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM) + self.sent = [] + self.typing = [] + + async def connect(self) -> bool: + return True + + async def disconnect(self) -> None: + return None + + async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult: + self.sent.append( + { + "chat_id": chat_id, + "content": content, + "reply_to": reply_to, + "metadata": metadata, + } + ) + return SendResult(success=True, message_id="sent-1") + + async def send_typing(self, chat_id, metadata=None) -> None: + self.typing.append({"chat_id": chat_id, "metadata": metadata}) + + async def stop_typing(self, chat_id) -> None: + return None + + async def get_chat_info(self, chat_id: str): + return {"id": chat_id} + + +class CaptureQueuedNativeImageAgent: + calls = [] + + def __init__(self, **kwargs): + self.tools = [] + self.tool_progress_callback = kwargs.get("tool_progress_callback") + + def run_conversation(self, message, conversation_history=None, task_id=None): + type(self).calls.append(message) + return { + "final_response": f"done-{len(type(self).calls)}", + "messages": [], + "api_calls": 1, + } + + +def _make_runner(adapter): + gateway_run = importlib.import_module("gateway.run") + runner = object.__new__(gateway_run.GatewayRunner) + runner.adapters = {adapter.platform: adapter} + runner._voice_mode = {} + runner._prefill_messages = [] + runner._ephemeral_system_prompt = "" + runner._reasoning_config = None + runner._provider_routing = {} + runner._fallback_model = None + runner._session_db = None + runner._running_agents = {} + runner._session_run_generation = {} + runner.hooks = SimpleNamespace(loaded_hooks=False) + runner.config = SimpleNamespace( + thread_sessions_per_user=False, + group_sessions_per_user=False, + stt_enabled=False, + ) + runner._model = "openai/gpt-4.1-mini" + runner._base_url = None + runner._decide_image_input_mode = lambda: "native" + return runner + + +@pytest.mark.asyncio +async def test_queued_followup_uses_pending_event_session_key_for_native_images(monkeypatch, tmp_path): + CaptureQueuedNativeImageAgent.calls = [] + + fake_dotenv = types.ModuleType("dotenv") + fake_dotenv.load_dotenv = lambda *args, **kwargs: None + monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = CaptureQueuedNativeImageAgent + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + gateway_run = importlib.import_module("gateway.run") + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}) + + adapter = CaptureAdapter() + runner = _make_runner(adapter) + + image_path = tmp_path / "queued-image.png" + image_path.write_bytes(_ONE_BY_ONE_PNG) + + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1001", + chat_type="group", + ) + pending_source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1001", + chat_type="group", + thread_id="17585", + ) + + adapter._pending_messages["agent:main:telegram:group:-1001"] = MessageEvent( + text="describe this", + message_type=MessageType.PHOTO, + source=pending_source, + media_urls=[str(image_path)], + media_types=["image/png"], + message_id="queued-1", + ) + + result = await runner._run_agent( + message="hello", + context_prompt="", + history=[], + source=source, + session_id="sess-native-image-followup", + session_key="agent:main:telegram:group:-1001", + ) + + assert result["final_response"] == "done-2" + assert len(CaptureQueuedNativeImageAgent.calls) == 2 + queued_message = CaptureQueuedNativeImageAgent.calls[1] + assert isinstance(queued_message, list) + assert queued_message[0]["type"] == "text" + assert queued_message[0]["text"].startswith("describe this") + assert any(part.get("type") == "image_url" for part in queued_message)