"""StateStream — per-job SSE event queue.""" from __future__ import annotations import asyncio import json import logging from collections.abc import AsyncGenerator logger = logging.getLogger(__name__) class StateStream: """Manages per-job asyncio.Queue for SSE events.""" def __init__(self) -> None: self._queues: dict[str, asyncio.Queue] = {} def create(self, job_id: str) -> None: self._queues[job_id] = asyncio.Queue() async def push(self, job_id: str, event: str, data: dict) -> None: q = self._queues.get(job_id) if q: await q.put((event, data)) async def push_done(self, job_id: str) -> None: """Push sentinel to signal stream end.""" q = self._queues.get(job_id) if q: await q.put(None) async def subscribe(self, job_id: str) -> AsyncGenerator[str, None]: """Yield SSE-formatted strings until done sentinel.""" q = self._queues.get(job_id) if not q: yield _sse("error", {"message": "Job not found"}) return while True: item = await q.get() if item is None: break event, data = item yield _sse(event, data) def cleanup(self, job_id: str) -> None: self._queues.pop(job_id, None) def _sse(event: str, data: dict) -> str: return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" state_stream = StateStream()