fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099
Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,821 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import threading
|
||||
import typing
|
||||
from concurrent.futures import Future
|
||||
|
||||
from ...._constant import DEFAULT_RETRIES
|
||||
from ....adapters import AsyncBaseAdapter, BaseAdapter
|
||||
from ....exceptions import ConnectTimeout, ReadTimeout
|
||||
from ....models import AsyncResponse, PreparedRequest, Response
|
||||
from ....packages.urllib3._async.response import AsyncHTTPResponse as BaseHTTPResponse
|
||||
from ....packages.urllib3.contrib.ssa._timeout import timeout as asyncio_timeout
|
||||
from ....packages.urllib3.exceptions import MaxRetryError
|
||||
from ....packages.urllib3.response import BytesQueueBuffer
|
||||
from ....packages.urllib3.util import Timeout as TimeoutSauce
|
||||
from ....packages.urllib3.util.retry import Retry
|
||||
from ....structures import CaseInsensitiveDict
|
||||
from ....utils import _swap_context
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ....typing import ASGIApp, ASGIMessage, ProxyType, RetryType, TLSClientCertType, TLSVerifyType
|
||||
|
||||
|
||||
class _ASGIRawIO:
|
||||
"""Async file-like wrapper around an ASGI response for true async streaming."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
response_queue: asyncio.Queue[ASGIMessage | None],
|
||||
response_complete: asyncio.Event,
|
||||
timeout: float | None = None,
|
||||
) -> None:
|
||||
self._response_queue = response_queue
|
||||
self._response_complete = response_complete
|
||||
self._timeout = timeout
|
||||
self._buffer = BytesQueueBuffer()
|
||||
self._closed = False
|
||||
self._finished = False
|
||||
self._task: asyncio.Task | None = None
|
||||
self.headers: dict | None = None
|
||||
self.extension: typing.Any = None
|
||||
|
||||
async def read(self, amt: int | None = None, decode_content: bool = True) -> bytes:
|
||||
if self._closed or self._finished:
|
||||
return self._buffer.get(len(self._buffer))
|
||||
|
||||
if amt is None or amt < 0:
|
||||
async for chunk in self._async_iter_chunks():
|
||||
self._buffer.put(chunk)
|
||||
self._finished = True
|
||||
return self._buffer.get(len(self._buffer))
|
||||
|
||||
while len(self._buffer) < amt and not self._finished:
|
||||
chunk = await self._get_next_chunk() # type: ignore[assignment]
|
||||
if chunk is None:
|
||||
self._finished = True
|
||||
break
|
||||
self._buffer.put(chunk)
|
||||
|
||||
if len(self._buffer) == 0:
|
||||
return b""
|
||||
|
||||
return self._buffer.get(min(amt, len(self._buffer)))
|
||||
|
||||
async def _get_next_chunk(self) -> bytes | None:
|
||||
try:
|
||||
async with asyncio_timeout(self._timeout):
|
||||
message = await self._response_queue.get()
|
||||
if message is None:
|
||||
return None
|
||||
if message["type"] == "http.response.body":
|
||||
return message.get("body", b"")
|
||||
return None
|
||||
except asyncio.TimeoutError:
|
||||
await self._cancel_task()
|
||||
raise ReadTimeout("Read timed out while streaming ASGI response")
|
||||
except asyncio.CancelledError:
|
||||
return None
|
||||
|
||||
async def _async_iter_chunks(self) -> typing.AsyncGenerator[bytes]:
|
||||
while True:
|
||||
try:
|
||||
async with asyncio_timeout(self._timeout):
|
||||
message = await self._response_queue.get()
|
||||
except asyncio.TimeoutError:
|
||||
await self._cancel_task()
|
||||
raise ReadTimeout("Read timed out while streaming ASGI response")
|
||||
if message is None:
|
||||
break
|
||||
if message["type"] == "http.response.body":
|
||||
chunk = message.get("body", b"")
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
async def _cancel_task(self) -> None:
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
|
||||
def stream(self, amt: int, decode_content: bool = True) -> typing.AsyncGenerator[bytes]:
|
||||
return self._async_stream(amt)
|
||||
|
||||
async def _async_stream(self, amt: int) -> typing.AsyncGenerator[bytes]:
|
||||
while True:
|
||||
chunk = await self.read(amt)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
self._response_complete.set()
|
||||
|
||||
def __aiter__(self) -> typing.AsyncIterator[bytes]:
|
||||
return self._async_iter_self()
|
||||
|
||||
async def _async_iter_self(self) -> typing.AsyncIterator[bytes]:
|
||||
async for chunk in self._async_iter_chunks():
|
||||
yield chunk
|
||||
|
||||
async def __anext__(self) -> bytes:
|
||||
chunk = await self.read(8192)
|
||||
if not chunk:
|
||||
raise StopAsyncIteration
|
||||
return chunk
|
||||
|
||||
|
||||
class AsyncServerGatewayInterface(AsyncBaseAdapter):
|
||||
"""Adapter for making requests to ASGI applications directly."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
raise_app_exceptions: bool = True,
|
||||
max_retries: RetryType = DEFAULT_RETRIES,
|
||||
lifespan_state: dict[str, typing.Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
self._lifespan_state = lifespan_state
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<ASGIAdapter Native/>"
|
||||
|
||||
async def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], typing.Awaitable[None]] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], typing.Awaitable[None]] | None = None,
|
||||
on_early_response: typing.Callable[[Response], typing.Awaitable[None]] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> AsyncResponse:
|
||||
"""Send a PreparedRequest to the ASGI application."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
retries = self.max_retries
|
||||
method = request.method or "GET"
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = await self._do_send(request, stream, timeout)
|
||||
except Exception as err:
|
||||
try:
|
||||
retries = retries.increment(method, request.url, error=err)
|
||||
except MaxRetryError:
|
||||
raise
|
||||
|
||||
await retries.async_sleep()
|
||||
continue
|
||||
|
||||
# we rely on the urllib3 implementation for retries
|
||||
# so we basically mock a response to get it to work
|
||||
base_response = BaseHTTPResponse(
|
||||
body=b"",
|
||||
headers=response.headers,
|
||||
status=response.status_code,
|
||||
request_method=request.method,
|
||||
request_url=request.url,
|
||||
)
|
||||
|
||||
# Check if we should retry based on status code
|
||||
has_retry_after = bool(response.headers.get("Retry-After"))
|
||||
|
||||
if retries.is_retry(method, response.status_code, has_retry_after):
|
||||
try:
|
||||
retries = retries.increment(method, request.url, response=base_response)
|
||||
except MaxRetryError:
|
||||
if retries.raise_on_status:
|
||||
raise
|
||||
return response
|
||||
|
||||
await retries.async_sleep(base_response)
|
||||
continue
|
||||
|
||||
return response
|
||||
|
||||
async def _do_send_ws(self, request: PreparedRequest) -> AsyncResponse:
|
||||
"""Handle WebSocket requests via the ASGI websocket protocol."""
|
||||
from ._ws import ASGIWebSocketExtension
|
||||
|
||||
scope = self._create_ws_scope(request)
|
||||
|
||||
ext = ASGIWebSocketExtension()
|
||||
await ext.start(self.app, scope)
|
||||
|
||||
response = Response()
|
||||
response.status_code = 101
|
||||
response.headers = CaseInsensitiveDict({"upgrade": "websocket"})
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
|
||||
raw_io = _ASGIRawIO(asyncio.Queue(), asyncio.Event())
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
|
||||
_swap_context(response)
|
||||
return response # type: ignore
|
||||
|
||||
async def _do_send_sse(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
timeout: int | float | None,
|
||||
) -> AsyncResponse:
|
||||
"""Handle SSE requests via ASGI HTTP streaming with SSE parsing."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._sse import ASGISSEExtension
|
||||
|
||||
# Convert scheme: sse:// -> https://, psse:// -> http://
|
||||
parsed = urlparse(request.url)
|
||||
if parsed.scheme == "sse":
|
||||
http_scheme = "https"
|
||||
else:
|
||||
http_scheme = "http"
|
||||
|
||||
# Build a modified request with http scheme for scope creation
|
||||
original_url = request.url
|
||||
request.url = request.url.replace(f"{parsed.scheme}://", f"{http_scheme}://", 1) # type: ignore[union-attr,str-bytes-safe]
|
||||
|
||||
scope = self._create_scope(request)
|
||||
request.url = original_url
|
||||
|
||||
body = request.body or b""
|
||||
if isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
|
||||
ext = ASGISSEExtension()
|
||||
start_message = await ext.start(self.app, scope, body) # type: ignore[arg-type]
|
||||
|
||||
status_code = start_message["status"]
|
||||
response_headers = start_message.get("headers", [])
|
||||
headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(headers_dict)
|
||||
response.request = request
|
||||
response.url = original_url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
|
||||
raw_io = _ASGIRawIO(asyncio.Queue(), asyncio.Event(), timeout)
|
||||
raw_io.headers = headers_dict
|
||||
raw_io.extension = ext
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False
|
||||
response._content_consumed = False
|
||||
|
||||
_swap_context(response)
|
||||
return response # type: ignore
|
||||
|
||||
async def _do_send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool,
|
||||
timeout: int | float | None,
|
||||
) -> AsyncResponse:
|
||||
"""Perform the actual ASGI request."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
if parsed.scheme in ("ws", "wss"):
|
||||
return await self._do_send_ws(request)
|
||||
|
||||
if parsed.scheme in ("sse", "psse"):
|
||||
return await self._do_send_sse(request, timeout)
|
||||
|
||||
scope = self._create_scope(request)
|
||||
|
||||
body = request.body or b""
|
||||
body_iter: typing.AsyncIterator[bytes] | typing.AsyncIterator[str] | None = None
|
||||
|
||||
# Check if body is an async iterable
|
||||
if hasattr(body, "__aiter__"):
|
||||
body_iter = body.__aiter__()
|
||||
body = b"" # Will be streamed
|
||||
elif isinstance(body, str):
|
||||
body = body.encode("utf-8")
|
||||
|
||||
request_complete = False
|
||||
response_complete = asyncio.Event()
|
||||
response_queue: asyncio.Queue[ASGIMessage | None] = asyncio.Queue()
|
||||
app_exception: Exception | None = None
|
||||
|
||||
async def receive() -> ASGIMessage:
|
||||
nonlocal request_complete
|
||||
if request_complete:
|
||||
await response_complete.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
if body_iter is not None:
|
||||
# Stream chunks from async iterable
|
||||
try:
|
||||
chunk = await body_iter.__anext__()
|
||||
if isinstance(chunk, str):
|
||||
chunk = chunk.encode("utf-8")
|
||||
return {"type": "http.request", "body": chunk, "more_body": True}
|
||||
except StopAsyncIteration:
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
else:
|
||||
# Single body chunk
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
async def send_func(message: ASGIMessage) -> None:
|
||||
await response_queue.put(message)
|
||||
if message["type"] == "http.response.body" and not message.get("more_body", False):
|
||||
response_complete.set()
|
||||
|
||||
async def run_app() -> None:
|
||||
nonlocal app_exception
|
||||
try:
|
||||
await self.app(scope, receive, send_func)
|
||||
except Exception as ex:
|
||||
app_exception = ex
|
||||
finally:
|
||||
await response_queue.put(None)
|
||||
|
||||
if stream:
|
||||
return await self._stream_response(
|
||||
request, response_queue, response_complete, run_app, lambda: app_exception, timeout
|
||||
)
|
||||
else:
|
||||
return await self._buffered_response(
|
||||
request, response_queue, response_complete, run_app, lambda: app_exception, timeout
|
||||
)
|
||||
|
||||
async def _stream_response(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
response_queue: asyncio.Queue[ASGIMessage | None],
|
||||
response_complete: asyncio.Event,
|
||||
run_app: typing.Callable[[], typing.Awaitable[None]],
|
||||
get_exception: typing.Callable[[], Exception | None],
|
||||
timeout: float | None,
|
||||
) -> AsyncResponse:
|
||||
status_code: int | None = None
|
||||
response_headers: list[tuple[bytes, bytes]] = []
|
||||
|
||||
task = asyncio.create_task(run_app()) # type: ignore[var-annotated,arg-type]
|
||||
|
||||
try:
|
||||
# Wait for http.response.start with timeout
|
||||
async with asyncio_timeout(timeout):
|
||||
while True:
|
||||
message = await response_queue.get()
|
||||
if message is None:
|
||||
break
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
response_headers = message.get("headers", [])
|
||||
break
|
||||
|
||||
headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
|
||||
|
||||
raw_io = _ASGIRawIO(response_queue, response_complete, timeout)
|
||||
raw_io.headers = headers_dict
|
||||
raw_io._task = task
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(headers_dict)
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
response.raw = raw_io # type: ignore
|
||||
response._content = False
|
||||
response._content_consumed = False
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise ConnectTimeout("Timed out waiting for ASGI response headers")
|
||||
|
||||
except Exception:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise
|
||||
|
||||
async def _buffered_response(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
response_queue: asyncio.Queue[ASGIMessage | None],
|
||||
response_complete: asyncio.Event,
|
||||
run_app: typing.Callable[[], typing.Awaitable[None]],
|
||||
get_exception: typing.Callable[[], Exception | None],
|
||||
timeout: float | None,
|
||||
) -> AsyncResponse:
|
||||
status_code: int | None = None
|
||||
response_headers: list[tuple[bytes, bytes]] = []
|
||||
body_chunks: list[bytes] = []
|
||||
|
||||
task = asyncio.create_task(run_app()) # type: ignore[var-annotated,arg-type]
|
||||
|
||||
try:
|
||||
async with asyncio_timeout(timeout):
|
||||
while True:
|
||||
message = await response_queue.get()
|
||||
if message is None:
|
||||
break
|
||||
if message["type"] == "http.response.start":
|
||||
status_code = message["status"]
|
||||
response_headers = message.get("headers", [])
|
||||
elif message["type"] == "http.response.body":
|
||||
chunk = message.get("body", b"")
|
||||
if chunk:
|
||||
body_chunks.append(chunk)
|
||||
|
||||
await task
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise ReadTimeout("Timed out reading ASGI response body")
|
||||
|
||||
except Exception:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
raise
|
||||
|
||||
if self.raise_app_exceptions and get_exception() is not None:
|
||||
raise get_exception() # type: ignore
|
||||
|
||||
headers_dict = {k.decode("latin-1"): v.decode("latin-1") for k, v in response_headers}
|
||||
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response.headers = CaseInsensitiveDict(headers_dict)
|
||||
response.request = request
|
||||
response.url = request.url
|
||||
response.encoding = response.headers.get("content-type", "utf-8") # type: ignore[assignment]
|
||||
response._content = b"".join(body_chunks)
|
||||
response.raw = _ASGIRawIO(response_queue, response_complete, timeout) # type: ignore
|
||||
response.raw.headers = headers_dict
|
||||
|
||||
_swap_context(response)
|
||||
|
||||
return response # type: ignore[return-value]
|
||||
|
||||
def _create_scope(self, request: PreparedRequest) -> dict:
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
headers: list[tuple[bytes, bytes]] = []
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
headers.append((key.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
scope: dict[str, typing.Any] = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"method": request.method,
|
||||
"scheme": "http",
|
||||
"path": unquote(parsed.path) or "/",
|
||||
"query_string": (parsed.query or "").encode("latin-1"), # type: ignore[union-attr]
|
||||
"root_path": "",
|
||||
"headers": headers,
|
||||
"server": (
|
||||
parsed.hostname or "localhost",
|
||||
parsed.port or (443 if parsed.scheme == "https" else 80),
|
||||
),
|
||||
}
|
||||
|
||||
# Include lifespan state if available (for frameworks like Starlette that use it)
|
||||
if self._lifespan_state is not None:
|
||||
scope["state"] = self._lifespan_state.copy()
|
||||
|
||||
return scope
|
||||
|
||||
def _create_ws_scope(self, request: PreparedRequest) -> dict:
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
headers: list[tuple[bytes, bytes]] = []
|
||||
if request.headers:
|
||||
for key, value in request.headers.items():
|
||||
headers.append((key.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
scheme = "wss" if parsed.scheme == "wss" else "ws"
|
||||
|
||||
scope: dict[str, typing.Any] = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"scheme": scheme,
|
||||
"path": unquote(parsed.path) or "/",
|
||||
"query_string": (parsed.query or "").encode("latin-1"), # type: ignore[union-attr]
|
||||
"root_path": "",
|
||||
"headers": headers,
|
||||
"server": (
|
||||
parsed.hostname or "localhost",
|
||||
parsed.port or (443 if scheme == "wss" else 80),
|
||||
),
|
||||
}
|
||||
|
||||
if self._lifespan_state is not None:
|
||||
scope["state"] = self._lifespan_state.copy()
|
||||
|
||||
return scope
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ThreadAsyncServerGatewayInterface(BaseAdapter):
|
||||
"""Synchronous adapter for ASGI applications using a background event loop."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
raise_app_exceptions: bool = True,
|
||||
max_retries: RetryType = DEFAULT_RETRIES,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.app = app
|
||||
self.raise_app_exceptions = raise_app_exceptions
|
||||
|
||||
if isinstance(max_retries, Retry):
|
||||
self.max_retries = max_retries
|
||||
else:
|
||||
self.max_retries = Retry.from_int(max_retries)
|
||||
|
||||
self._async_adapter: AsyncServerGatewayInterface | None = None
|
||||
self._loop: typing.Any = None # asyncio.AbstractEventLoop
|
||||
self._thread: threading.Thread | None = None
|
||||
self._started = threading.Event()
|
||||
self._lifespan_task: asyncio.Task | None = None
|
||||
self._lifespan_receive_queue: asyncio.Queue[ASGIMessage] | None = None
|
||||
self._lifespan_startup_complete = threading.Event()
|
||||
self._lifespan_startup_failed: Exception | None = None
|
||||
self._lifespan_state: dict[str, typing.Any] = {}
|
||||
self._startup_lock: threading.Lock = threading.Lock()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<ASGIAdapter Thread/>"
|
||||
|
||||
def _ensure_loop_running(self) -> None:
|
||||
"""Start the background event loop thread if not already running."""
|
||||
with self._startup_lock:
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
return
|
||||
|
||||
import asyncio
|
||||
|
||||
def run_loop() -> None:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._lifespan_receive_queue = asyncio.Queue()
|
||||
self._async_adapter = AsyncServerGatewayInterface(
|
||||
self.app,
|
||||
raise_app_exceptions=self.raise_app_exceptions,
|
||||
max_retries=self.max_retries,
|
||||
lifespan_state=self._lifespan_state,
|
||||
)
|
||||
# Start lifespan handler
|
||||
self._lifespan_task = self._loop.create_task(self._handle_lifespan())
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
self._thread = threading.Thread(target=run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._started.wait()
|
||||
self._lifespan_startup_complete.wait()
|
||||
|
||||
if self._lifespan_startup_failed is not None:
|
||||
raise self._lifespan_startup_failed
|
||||
|
||||
async def _handle_lifespan(self) -> None:
|
||||
"""Handle ASGI lifespan protocol."""
|
||||
scope = {
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": "3.0"},
|
||||
"state": self._lifespan_state,
|
||||
}
|
||||
|
||||
startup_complete = asyncio.Event()
|
||||
shutdown_complete = asyncio.Event()
|
||||
startup_failed: list[Exception] = []
|
||||
|
||||
# Keep local reference to avoid race condition during shutdown
|
||||
receive_queue = self._lifespan_receive_queue
|
||||
|
||||
async def receive() -> ASGIMessage:
|
||||
return await receive_queue.get() # type: ignore[union-attr]
|
||||
|
||||
async def send(message: ASGIMessage) -> None:
|
||||
if message["type"] == "lifespan.startup.complete":
|
||||
startup_complete.set()
|
||||
elif message["type"] == "lifespan.startup.failed":
|
||||
startup_failed.append(RuntimeError(message.get("message", "Lifespan startup failed")))
|
||||
startup_complete.set()
|
||||
elif message["type"] == "lifespan.shutdown.complete":
|
||||
shutdown_complete.set()
|
||||
elif message["type"] == "lifespan.shutdown.failed":
|
||||
shutdown_complete.set()
|
||||
|
||||
async def run_lifespan() -> None:
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
except Exception as e:
|
||||
if not startup_complete.is_set():
|
||||
startup_failed.append(e)
|
||||
startup_complete.set()
|
||||
|
||||
lifespan_task = asyncio.create_task(run_lifespan())
|
||||
|
||||
# Send startup event
|
||||
await receive_queue.put({"type": "lifespan.startup"}) # type: ignore[union-attr]
|
||||
await startup_complete.wait()
|
||||
|
||||
if startup_failed:
|
||||
self._lifespan_startup_failed = startup_failed[0]
|
||||
self._lifespan_startup_complete.set()
|
||||
|
||||
# Wait for shutdown signal (loop.stop() will cancel this)
|
||||
try:
|
||||
await asyncio.Future() # Wait forever until canceled
|
||||
except (asyncio.CancelledError, GeneratorExit):
|
||||
pass
|
||||
|
||||
# Send shutdown event - must happen before loop stops
|
||||
if receive_queue is not None:
|
||||
try:
|
||||
await receive_queue.put({"type": "lifespan.shutdown"})
|
||||
await asyncio.wait_for(shutdown_complete.wait(), timeout=5.0)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError, RuntimeError):
|
||||
pass
|
||||
lifespan_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await lifespan_task
|
||||
|
||||
def _do_send_ws(self, request: PreparedRequest) -> Response:
|
||||
"""Handle WebSocket requests synchronously via the background loop."""
|
||||
from .._ws import ThreadASGIWebSocketExtension
|
||||
|
||||
self._ensure_loop_running()
|
||||
|
||||
future: Future[Response] = Future()
|
||||
|
||||
async def run_ws() -> None:
|
||||
try:
|
||||
result = await self._async_adapter._do_send_ws(request) # type: ignore[union-attr]
|
||||
_swap_context(result)
|
||||
|
||||
# Wrap the async WS extension in a sync wrapper
|
||||
async_ext = result.raw.extension # type: ignore[union-attr]
|
||||
result.raw.extension = ThreadASGIWebSocketExtension(async_ext, self._loop) # type: ignore[union-attr]
|
||||
|
||||
future.set_result(result) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_ws()))
|
||||
return future.result()
|
||||
|
||||
def _do_send_sse(self, request: PreparedRequest, timeout: int | float | None = None) -> Response:
|
||||
"""Handle SSE requests synchronously via the background loop."""
|
||||
from .._sse import ThreadASGISSEExtension
|
||||
|
||||
self._ensure_loop_running()
|
||||
|
||||
future: Future[Response] = Future()
|
||||
|
||||
async def run_sse() -> None:
|
||||
try:
|
||||
result = await self._async_adapter._do_send_sse(request, timeout) # type: ignore[union-attr]
|
||||
_swap_context(result)
|
||||
|
||||
# Wrap the async SSE extension in a sync wrapper
|
||||
async_ext = result.raw.extension # type: ignore[union-attr]
|
||||
result.raw.extension = ThreadASGISSEExtension(async_ext, self._loop) # type: ignore[union-attr]
|
||||
|
||||
future.set_result(result) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_sse()))
|
||||
return future.result()
|
||||
|
||||
def send(
|
||||
self,
|
||||
request: PreparedRequest,
|
||||
stream: bool = False,
|
||||
timeout: int | float | tuple | TimeoutSauce | None = None,
|
||||
verify: TLSVerifyType = True,
|
||||
cert: TLSClientCertType | None = None,
|
||||
proxies: ProxyType | None = None,
|
||||
on_post_connection: typing.Callable[[typing.Any], None] | None = None,
|
||||
on_upload_body: typing.Callable[[int, int | None, bool, bool], None] | None = None,
|
||||
on_early_response: typing.Callable[[Response], None] | None = None,
|
||||
multiplexed: bool = False,
|
||||
) -> Response:
|
||||
"""Send a PreparedRequest to the ASGI application synchronously."""
|
||||
if isinstance(timeout, tuple):
|
||||
if len(timeout) == 3:
|
||||
timeout = timeout[2] or timeout[0] # prefer total, fallback connect
|
||||
else:
|
||||
timeout = timeout[0] # use connect
|
||||
elif isinstance(timeout, TimeoutSauce):
|
||||
timeout = timeout.total or timeout.connect_timeout
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
if parsed.scheme in ("ws", "wss"):
|
||||
return self._do_send_ws(request)
|
||||
|
||||
if parsed.scheme in ("sse", "psse"):
|
||||
return self._do_send_sse(request, timeout)
|
||||
|
||||
if stream:
|
||||
raise ValueError(
|
||||
"ThreadAsyncServerGatewayInterface does not support streaming responses. "
|
||||
"Use stream=False or migrate to pure async/await implementation."
|
||||
)
|
||||
|
||||
self._ensure_loop_running()
|
||||
|
||||
future: Future[Response] = Future()
|
||||
|
||||
async def run_send() -> None:
|
||||
try:
|
||||
result = await self._async_adapter.send( # type: ignore[union-attr]
|
||||
request,
|
||||
stream=False,
|
||||
timeout=timeout,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
proxies=proxies,
|
||||
)
|
||||
_swap_context(result)
|
||||
future.set_result(result) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
future.set_exception(e)
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(run_send()))
|
||||
|
||||
return future.result()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up adapter resources."""
|
||||
if self._loop is not None and self._lifespan_task is not None:
|
||||
# Signal shutdown and wait for it to complete
|
||||
shutdown_done = threading.Event()
|
||||
|
||||
async def do_shutdown() -> None:
|
||||
if self._lifespan_task is not None:
|
||||
self._lifespan_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._lifespan_task
|
||||
shutdown_done.set()
|
||||
|
||||
self._loop.call_soon_threadsafe(lambda: self._loop.create_task(do_shutdown()))
|
||||
shutdown_done.wait(timeout=6.0)
|
||||
|
||||
if self._loop is not None:
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread is not None:
|
||||
self._thread.join(timeout=5.0)
|
||||
self._thread = None
|
||||
# Clear resources only after thread has stopped
|
||||
self._loop = None
|
||||
self._async_adapter = None
|
||||
self._lifespan_task = None
|
||||
self._lifespan_receive_queue = None
|
||||
self._started.clear()
|
||||
self._lifespan_startup_complete.clear()
|
||||
self._lifespan_startup_failed = None
|
||||
|
||||
|
||||
__all__ = ("AsyncServerGatewayInterface", "ThreadAsyncServerGatewayInterface")
|
||||
@@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
from ....packages.urllib3.contrib.webextensions.sse import ServerSentEvent
|
||||
|
||||
|
||||
class ASGISSEExtension:
|
||||
"""Async SSE extension for ASGI applications.
|
||||
|
||||
Runs the ASGI app as a normal HTTP streaming request and parses
|
||||
SSE events from the response body chunks.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._buffer: str = ""
|
||||
self._last_event_id: str | None = None
|
||||
self._response_queue: asyncio.Queue[dict[str, typing.Any] | None] | None = None
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(
|
||||
self,
|
||||
app: typing.Any,
|
||||
scope: dict[str, typing.Any],
|
||||
body: bytes = b"",
|
||||
) -> dict[str, typing.Any]:
|
||||
"""Start the ASGI app and wait for http.response.start.
|
||||
|
||||
Returns the response start message (with status and headers).
|
||||
"""
|
||||
self._response_queue = asyncio.Queue()
|
||||
|
||||
request_complete = False
|
||||
response_complete = asyncio.Event()
|
||||
|
||||
async def receive() -> dict[str, typing.Any]:
|
||||
nonlocal request_complete
|
||||
if request_complete:
|
||||
await response_complete.wait()
|
||||
return {"type": "http.disconnect"}
|
||||
request_complete = True
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
async def send(message: dict[str, typing.Any]) -> None:
|
||||
await self._response_queue.put(message) # type: ignore[union-attr]
|
||||
if message["type"] == "http.response.body" and not message.get("more_body", False):
|
||||
response_complete.set()
|
||||
|
||||
async def run_app() -> None:
|
||||
try:
|
||||
await app(scope, receive, send)
|
||||
finally:
|
||||
await self._response_queue.put(None) # type: ignore[union-attr]
|
||||
|
||||
self._task = asyncio.create_task(run_app())
|
||||
|
||||
# Wait for http.response.start
|
||||
while True:
|
||||
message = await self._response_queue.get()
|
||||
if message is None:
|
||||
raise ConnectionError("ASGI app closed before sending response headers")
|
||||
if message["type"] == "http.response.start":
|
||||
return message
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Read and parse the next SSE event from the ASGI response stream.
|
||||
Returns None when the stream ends."""
|
||||
if self._closed:
|
||||
raise OSError("The SSE extension is closed")
|
||||
|
||||
while True:
|
||||
# Check if we already have a complete event in the buffer
|
||||
sep_idx = self._buffer.find("\n\n")
|
||||
if sep_idx == -1:
|
||||
sep_idx = self._buffer.find("\r\n\r\n")
|
||||
if sep_idx != -1:
|
||||
sep_len = 4
|
||||
else:
|
||||
sep_len = 2
|
||||
else:
|
||||
sep_len = 2
|
||||
|
||||
if sep_idx != -1:
|
||||
raw_event = self._buffer[:sep_idx]
|
||||
self._buffer = self._buffer[sep_idx + sep_len :]
|
||||
|
||||
event = self._parse_event(raw_event)
|
||||
|
||||
if event is not None:
|
||||
if raw:
|
||||
return raw_event + "\n\n"
|
||||
return event
|
||||
|
||||
# Empty event (e.g. just comments), try next
|
||||
continue
|
||||
|
||||
# Need more data from the ASGI response queue
|
||||
chunk = await self._read_chunk()
|
||||
if chunk is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
self._buffer += chunk
|
||||
|
||||
async def _read_chunk(self) -> str | None:
|
||||
"""Read the next body chunk from the ASGI response."""
|
||||
if self._response_queue is None:
|
||||
return None
|
||||
|
||||
message = await self._response_queue.get()
|
||||
if message is None:
|
||||
return None
|
||||
if message["type"] == "http.response.body":
|
||||
body = message.get("body", b"")
|
||||
if body:
|
||||
return body.decode("utf-8")
|
||||
return None
|
||||
|
||||
def _parse_event(self, raw_event: str) -> ServerSentEvent | None:
|
||||
"""Parse a raw SSE event block into a ServerSentEvent."""
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
|
||||
for line in raw_event.splitlines():
|
||||
if not line or line.startswith(":"):
|
||||
continue
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if not kwargs:
|
||||
return None
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
return event
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""SSE is one-way only."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the SSE stream and clean up the app task."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import typing
|
||||
|
||||
|
||||
class ASGIWebSocketExtension:
|
||||
"""Async WebSocket extension for ASGI applications.
|
||||
|
||||
Uses the ASGI websocket protocol with send/receive queues to communicate
|
||||
with the application task.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._closed = False
|
||||
self._app_send_queue: asyncio.Queue[dict[str, typing.Any]] = asyncio.Queue()
|
||||
self._app_receive_queue: asyncio.Queue[dict[str, typing.Any]] = asyncio.Queue()
|
||||
self._task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(self, app: typing.Any, scope: dict[str, typing.Any]) -> None:
|
||||
"""Start the ASGI app task and perform the WebSocket handshake."""
|
||||
|
||||
async def receive() -> dict[str, typing.Any]:
|
||||
return await self._app_receive_queue.get()
|
||||
|
||||
async def send(message: dict[str, typing.Any]) -> None:
|
||||
await self._app_send_queue.put(message)
|
||||
|
||||
self._task = asyncio.create_task(app(scope, receive, send))
|
||||
|
||||
# Send connect and wait for accept
|
||||
await self._app_receive_queue.put({"type": "websocket.connect"})
|
||||
|
||||
message = await self._app_send_queue.get()
|
||||
|
||||
if message["type"] == "websocket.close":
|
||||
self._closed = True
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
raise ConnectionError(f"WebSocket connection rejected with code {message.get('code', 1000)}")
|
||||
|
||||
if message["type"] != "websocket.accept":
|
||||
self._closed = True
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
raise ConnectionError(f"Unexpected ASGI message during handshake: {message['type']}")
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
async def next_payload(self) -> str | bytes | None:
|
||||
"""Await the next message from the ASGI WebSocket app.
|
||||
Returns None when the app closes the connection."""
|
||||
if self._closed:
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
message = await self._app_send_queue.get()
|
||||
|
||||
if message["type"] == "websocket.send":
|
||||
if "text" in message:
|
||||
return message["text"]
|
||||
if "bytes" in message:
|
||||
return message["bytes"]
|
||||
return b""
|
||||
|
||||
if message["type"] == "websocket.close":
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Send a message to the ASGI WebSocket app."""
|
||||
if self._closed:
|
||||
raise OSError("The WebSocket extension is closed")
|
||||
|
||||
if isinstance(buf, (bytes, bytearray)):
|
||||
await self._app_receive_queue.put({"type": "websocket.receive", "bytes": bytes(buf)})
|
||||
else:
|
||||
await self._app_receive_queue.put({"type": "websocket.receive", "text": buf})
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket and clean up the app task."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
await self._app_receive_queue.put({"type": "websocket.disconnect", "code": 1000})
|
||||
|
||||
if self._task is not None and not self._task.done():
|
||||
self._task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._task
|
||||
Reference in New Issue
Block a user