fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099
Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .protocol import ExtensionFromHTTP
|
||||
from .raw import RawExtensionFromHTTP
|
||||
from .sse import ServerSideEventExtensionFromHTTP
|
||||
|
||||
try:
|
||||
from .ws import WebSocketExtensionFromHTTP, WebSocketExtensionFromMultiplexedHTTP
|
||||
except ImportError:
|
||||
WebSocketExtensionFromHTTP = None # type: ignore[misc, assignment]
|
||||
WebSocketExtensionFromMultiplexedHTTP = None # type: ignore[misc, assignment]
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def recursive_subclasses(cls: type[T]) -> list[type[T]]:
|
||||
all_subclasses = []
|
||||
|
||||
for subclass in cls.__subclasses__():
|
||||
all_subclasses.append(subclass)
|
||||
all_subclasses.extend(recursive_subclasses(subclass))
|
||||
|
||||
return all_subclasses
|
||||
|
||||
|
||||
def load_extension(
|
||||
scheme: str | None, implementation: str | None = None
|
||||
) -> type[ExtensionFromHTTP]:
|
||||
if scheme is None:
|
||||
return RawExtensionFromHTTP
|
||||
|
||||
scheme = scheme.lower()
|
||||
|
||||
if implementation:
|
||||
implementation = implementation.lower()
|
||||
|
||||
for extension in recursive_subclasses(ExtensionFromHTTP):
|
||||
if scheme in extension.supported_schemes():
|
||||
if (
|
||||
implementation is not None
|
||||
and extension.implementation() != implementation
|
||||
):
|
||||
continue
|
||||
return extension
|
||||
|
||||
raise ImportError(
|
||||
f"Tried to load HTTP extension '{scheme}' but no available plugin support it."
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"ExtensionFromHTTP",
|
||||
"RawExtensionFromHTTP",
|
||||
"WebSocketExtensionFromHTTP",
|
||||
"WebSocketExtensionFromMultiplexedHTTP",
|
||||
"ServerSideEventExtensionFromHTTP",
|
||||
"load_extension",
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .protocol import AsyncExtensionFromHTTP
|
||||
from .raw import AsyncRawExtensionFromHTTP
|
||||
from .sse import AsyncServerSideEventExtensionFromHTTP
|
||||
|
||||
try:
|
||||
from .ws import (
|
||||
AsyncWebSocketExtensionFromHTTP,
|
||||
AsyncWebSocketExtensionFromMultiplexedHTTP,
|
||||
)
|
||||
except ImportError:
|
||||
AsyncWebSocketExtensionFromHTTP = None # type: ignore[misc, assignment]
|
||||
AsyncWebSocketExtensionFromMultiplexedHTTP = None # type: ignore[misc, assignment]
|
||||
|
||||
from .. import recursive_subclasses
|
||||
|
||||
|
||||
def load_extension(
|
||||
scheme: str | None, implementation: str | None = None
|
||||
) -> type[AsyncExtensionFromHTTP]:
|
||||
if scheme is None:
|
||||
return AsyncRawExtensionFromHTTP
|
||||
|
||||
scheme = scheme.lower()
|
||||
|
||||
if implementation:
|
||||
implementation = implementation.lower()
|
||||
|
||||
for extension in recursive_subclasses(AsyncExtensionFromHTTP):
|
||||
if scheme in extension.supported_schemes():
|
||||
if (
|
||||
implementation is not None
|
||||
and extension.implementation() != implementation
|
||||
):
|
||||
continue
|
||||
return extension
|
||||
|
||||
raise ImportError(
|
||||
f"Tried to load HTTP extension '{scheme}' but no available plugin support it."
|
||||
)
|
||||
|
||||
|
||||
__all__ = (
|
||||
"AsyncExtensionFromHTTP",
|
||||
"AsyncRawExtensionFromHTTP",
|
||||
"AsyncWebSocketExtensionFromHTTP",
|
||||
"AsyncWebSocketExtensionFromMultiplexedHTTP",
|
||||
"AsyncServerSideEventExtensionFromHTTP",
|
||||
"load_extension",
|
||||
)
|
||||
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from abc import ABCMeta
|
||||
from contextlib import asynccontextmanager
|
||||
from socket import timeout as SocketTimeout
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...._async.response import AsyncHTTPResponse
|
||||
from ....backend import HttpVersion
|
||||
from ....backend._async._base import AsyncDirectStreamAccess
|
||||
from ....util._async.traffic_police import AsyncTrafficPolice
|
||||
|
||||
from ....exceptions import (
|
||||
BaseSSLError,
|
||||
ProtocolError,
|
||||
ReadTimeoutError,
|
||||
SSLError,
|
||||
MustRedialError,
|
||||
)
|
||||
|
||||
|
||||
class AsyncExtensionFromHTTP(metaclass=ABCMeta):
|
||||
"""Represent an extension that can be negotiated just after a "101 Switching Protocol" HTTP response.
|
||||
This will considerably ease downstream integration."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._dsa: AsyncDirectStreamAccess | None = None
|
||||
self._response: AsyncHTTPResponse | None = None
|
||||
self._police_officer: AsyncTrafficPolice | None = None # type: ignore[type-arg]
|
||||
|
||||
@asynccontextmanager
|
||||
async def _read_error_catcher(self) -> typing.AsyncGenerator[None, None]:
|
||||
"""
|
||||
Catch low-level python exceptions, instead re-raising urllib3
|
||||
variants, so that low-level exceptions are not leaked in the
|
||||
high-level api.
|
||||
|
||||
On unrecoverable issues, release the connection back to the pool.
|
||||
"""
|
||||
clean_exit = False
|
||||
|
||||
try:
|
||||
try:
|
||||
yield
|
||||
|
||||
except SocketTimeout as e:
|
||||
clean_exit = True
|
||||
pool = (
|
||||
self._response._pool
|
||||
if self._response and hasattr(self._response, "_pool")
|
||||
else None
|
||||
)
|
||||
raise ReadTimeoutError(pool, None, "Read timed out.") from e # type: ignore[arg-type]
|
||||
|
||||
except BaseSSLError as e:
|
||||
# FIXME: Is there a better way to differentiate between SSLErrors?
|
||||
if "read operation timed out" not in str(e):
|
||||
# SSL errors related to framing/MAC get wrapped and reraised here
|
||||
raise SSLError(e) from e
|
||||
clean_exit = True # ws algorithms based on timeouts can expect this without being harmful!
|
||||
pool = (
|
||||
self._response._pool
|
||||
if self._response and hasattr(self._response, "_pool")
|
||||
else None
|
||||
)
|
||||
raise ReadTimeoutError(pool, None, "Read timed out.") from e # type: ignore[arg-type]
|
||||
|
||||
except (OSError, MustRedialError) as e:
|
||||
# This includes IncompleteRead.
|
||||
raise ProtocolError(f"Connection broken: {e!r}", e) from e
|
||||
|
||||
# If no exception is thrown, we should avoid cleaning up
|
||||
# unnecessarily.
|
||||
clean_exit = True
|
||||
finally:
|
||||
# If we didn't terminate cleanly, we need to throw away our
|
||||
# connection.
|
||||
if not clean_exit:
|
||||
# The response may not be closed but we're not going to use it
|
||||
# anymore so close it now to ensure that the connection is
|
||||
# released back to the pool.
|
||||
if self._response:
|
||||
await self.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _write_error_catcher(self) -> typing.AsyncGenerator[None, None]:
|
||||
"""
|
||||
Catch low-level python exceptions, instead re-raising urllib3
|
||||
variants, so that low-level exceptions are not leaked in the
|
||||
high-level api.
|
||||
|
||||
On unrecoverable issues, release the connection back to the pool.
|
||||
"""
|
||||
clean_exit = False
|
||||
|
||||
try:
|
||||
try:
|
||||
yield
|
||||
|
||||
except SocketTimeout as e:
|
||||
pool = (
|
||||
self._response._pool
|
||||
if self._response and hasattr(self._response, "_pool")
|
||||
else None
|
||||
)
|
||||
raise ReadTimeoutError(pool, None, "Read timed out.") from e # type: ignore[arg-type]
|
||||
|
||||
except BaseSSLError as e:
|
||||
raise SSLError(e) from e
|
||||
|
||||
except OSError as e:
|
||||
# This includes IncompleteRead.
|
||||
raise ProtocolError(f"Connection broken: {e!r}", e) from e
|
||||
|
||||
# If no exception is thrown, we should avoid cleaning up
|
||||
# unnecessarily.
|
||||
clean_exit = True
|
||||
finally:
|
||||
# If we didn't terminate cleanly, we need to throw away our
|
||||
# connection.
|
||||
if not clean_exit:
|
||||
# The response may not be closed but we're not going to use it
|
||||
# anymore so close it now to ensure that the connection is
|
||||
# released back to the pool.
|
||||
if self._response:
|
||||
await self.close()
|
||||
|
||||
@property
|
||||
def urlopen_kwargs(self) -> dict[str, typing.Any]:
|
||||
return {}
|
||||
|
||||
async def start(self, response: AsyncHTTPResponse) -> None:
|
||||
"""The HTTP server gave us the go-to start negotiating another protocol."""
|
||||
if response._fp is None or not hasattr(response._fp, "_dsa"):
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
self._dsa = response._fp._dsa
|
||||
self._police_officer = response._police_officer
|
||||
self._response = response
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._dsa is None
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
"""Hint about supported parent SVN for this extension."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def supported_schemes() -> set[str]:
|
||||
"""Recognized schemes for the extension."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
"""Convert the extension scheme to a known http scheme (either http or https)"""
|
||||
raise NotImplementedError
|
||||
|
||||
def headers(self, http_version: HttpVersion) -> dict[str, str]:
|
||||
"""Specific HTTP headers required (request) before the 101 status response."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def close(self) -> None:
|
||||
"""End/Notify close for sub protocol."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def next_payload(self) -> str | bytes | None:
|
||||
"""Unpack the next received message/payload from remote. This call does read from the socket.
|
||||
If the method return None, it means that the remote closed the (extension) pipeline.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Dispatch a buffer to remote."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def on_payload(
|
||||
self, callback: typing.Callable[[str | bytes | None], typing.Awaitable[None]]
|
||||
) -> None:
|
||||
"""Set up a callback that will be invoked automatically once a payload is received.
|
||||
Meaning that you stop calling manually next_payload()."""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ....backend import HttpVersion
|
||||
from .protocol import AsyncExtensionFromHTTP
|
||||
|
||||
|
||||
class AsyncRawExtensionFromHTTP(AsyncExtensionFromHTTP):
|
||||
"""Raw I/O from given HTTP stream after a 101 Switching Protocol Status."""
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
|
||||
|
||||
def headers(self, http_version: HttpVersion) -> dict[str, str]:
|
||||
"""Specific HTTP headers required (request) before the 101 status response."""
|
||||
return {}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""End/Notify close for sub protocol."""
|
||||
if self._dsa is not None:
|
||||
await self._dsa.close()
|
||||
self._dsa = None
|
||||
if self._response is not None:
|
||||
await self._response.close()
|
||||
self._response = None
|
||||
self._police_officer = None
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "raw"
|
||||
|
||||
@staticmethod
|
||||
def supported_schemes() -> set[str]:
|
||||
return set()
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return scheme
|
||||
|
||||
async def next_payload(self) -> bytes | None:
|
||||
if self._police_officer is None or self._dsa is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
async with self._police_officer.borrow(self._response):
|
||||
async with self._read_error_catcher():
|
||||
data, eot, _ = await self._dsa.recv_extended(None)
|
||||
return data
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
if self._police_officer is None or self._dsa is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
if isinstance(buf, str):
|
||||
buf = buf.encode()
|
||||
|
||||
async with self._police_officer.borrow(self._response):
|
||||
async with self._write_error_catcher():
|
||||
await self._dsa.sendall(buf)
|
||||
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...._async.response import AsyncHTTPResponse
|
||||
|
||||
from ....backend import HttpVersion
|
||||
from ..sse import ServerSentEvent
|
||||
from .protocol import AsyncExtensionFromHTTP
|
||||
|
||||
|
||||
class AsyncServerSideEventExtensionFromHTTP(AsyncExtensionFromHTTP):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._last_event_id: str | None = None
|
||||
self._buffer: str = ""
|
||||
self._stream: typing.AsyncGenerator[bytes, None] | None = None
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "native"
|
||||
|
||||
@property
|
||||
def urlopen_kwargs(self) -> dict[str, typing.Any]:
|
||||
return {"preload_content": False}
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._stream is not None and self._response is not None:
|
||||
await self._stream.aclose()
|
||||
if (
|
||||
self._response._fp is not None
|
||||
and self._police_officer is not None
|
||||
and hasattr(self._response._fp, "abort")
|
||||
):
|
||||
async with self._police_officer.borrow(self._response):
|
||||
await self._response._fp.abort()
|
||||
self._stream = None
|
||||
self._response = None
|
||||
self._police_officer = None
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._stream is None
|
||||
|
||||
async def start(self, response: AsyncHTTPResponse) -> None:
|
||||
await super().start(response)
|
||||
|
||||
self._stream = response.stream(-1, decode_content=True)
|
||||
|
||||
def headers(self, http_version: HttpVersion) -> dict[str, str]:
|
||||
return {"accept": "text/event-stream", "cache-control": "no-store"}
|
||||
|
||||
@typing.overload
|
||||
async def next_payload(self, *, raw: typing.Literal[True] = True) -> str | None: ...
|
||||
|
||||
@typing.overload
|
||||
async def next_payload(
|
||||
self, *, raw: typing.Literal[False] = False
|
||||
) -> ServerSentEvent | None: ...
|
||||
|
||||
async def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Unpack the next received message/payload from remote."""
|
||||
if self._response is None or self._stream is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
try:
|
||||
raw_payload: str = (await self._stream.__anext__()).decode("utf-8")
|
||||
except StopAsyncIteration:
|
||||
await self._stream.aclose()
|
||||
self._stream = None
|
||||
return None
|
||||
|
||||
if self._buffer:
|
||||
raw_payload = self._buffer + raw_payload
|
||||
self._buffer = ""
|
||||
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
eot = False
|
||||
|
||||
for line in raw_payload.splitlines():
|
||||
if not line:
|
||||
eot = True
|
||||
break
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if eot is False:
|
||||
self._buffer = raw_payload
|
||||
return await self.next_payload(raw=raw) # type: ignore[call-overload,no-any-return]
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
if raw is True:
|
||||
return raw_payload
|
||||
|
||||
return event
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Dispatch a buffer to remote."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
@staticmethod
|
||||
def supported_schemes() -> set[str]:
|
||||
return {"sse", "psse"}
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"sse": "https", "psse": "http"}[scheme]
|
||||
@@ -0,0 +1,238 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...._async.response import AsyncHTTPResponse
|
||||
|
||||
from wsproto import ConnectionType, WSConnection
|
||||
from wsproto.events import (
|
||||
AcceptConnection,
|
||||
BytesMessage,
|
||||
CloseConnection,
|
||||
Ping,
|
||||
Pong,
|
||||
Request,
|
||||
TextMessage,
|
||||
)
|
||||
from wsproto.extensions import PerMessageDeflate
|
||||
from wsproto.utilities import ProtocolError as WebSocketProtocolError
|
||||
|
||||
from ....backend import HttpVersion
|
||||
from ....exceptions import ProtocolError
|
||||
from .protocol import AsyncExtensionFromHTTP
|
||||
|
||||
|
||||
class AsyncWebSocketExtensionFromHTTP(AsyncExtensionFromHTTP):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._protocol = WSConnection(ConnectionType.CLIENT)
|
||||
self._request_headers: dict[str, str] | None = None
|
||||
self._remote_shutdown: bool = False
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
return {HttpVersion.h11}
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "wsproto"
|
||||
|
||||
async def start(self, response: AsyncHTTPResponse) -> None:
|
||||
await super().start(response)
|
||||
|
||||
fake_http_response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
|
||||
|
||||
fake_http_response += b"Sec-Websocket-Accept: "
|
||||
|
||||
accept_token: str | None = response.headers.get("Sec-Websocket-Accept")
|
||||
|
||||
if accept_token is None:
|
||||
raise ProtocolError(
|
||||
"The WebSocket HTTP extension requires 'Sec-Websocket-Accept' header in the server response but was not present."
|
||||
)
|
||||
|
||||
fake_http_response += accept_token.encode() + b"\r\n"
|
||||
|
||||
if "sec-websocket-extensions" in response.headers:
|
||||
fake_http_response += (
|
||||
b"Sec-Websocket-Extensions: "
|
||||
+ response.headers.get("sec-websocket-extensions").encode() # type: ignore[union-attr]
|
||||
+ b"\r\n"
|
||||
)
|
||||
|
||||
fake_http_response += b"\r\n"
|
||||
|
||||
try:
|
||||
self._protocol.receive_data(fake_http_response)
|
||||
except WebSocketProtocolError as e:
|
||||
raise ProtocolError from e # Defensive: should never occur!
|
||||
|
||||
event = next(self._protocol.events())
|
||||
|
||||
if not isinstance(event, AcceptConnection):
|
||||
raise RuntimeError(
|
||||
"The WebSocket state-machine did not pass the handshake phase when expected."
|
||||
)
|
||||
|
||||
def headers(self, http_version: HttpVersion) -> dict[str, str]:
|
||||
"""Specific HTTP headers required (request) before the 101 status response."""
|
||||
if self._request_headers is not None:
|
||||
return self._request_headers
|
||||
|
||||
try:
|
||||
raw_data_to_socket = self._protocol.send(
|
||||
Request(
|
||||
host="example.com", target="/", extensions=(PerMessageDeflate(),)
|
||||
)
|
||||
)
|
||||
except WebSocketProtocolError as e:
|
||||
raise ProtocolError from e # Defensive: should never occur!
|
||||
|
||||
raw_headers = raw_data_to_socket.split(b"\r\n")[2:-2]
|
||||
request_headers: dict[str, str] = {}
|
||||
|
||||
for raw_header in raw_headers:
|
||||
k, v = raw_header.decode().split(": ")
|
||||
request_headers[k.lower()] = v
|
||||
|
||||
if http_version != HttpVersion.h11:
|
||||
del request_headers["upgrade"]
|
||||
del request_headers["connection"]
|
||||
request_headers[":protocol"] = "websocket"
|
||||
request_headers[":method"] = "CONNECT"
|
||||
|
||||
self._request_headers = request_headers
|
||||
|
||||
return request_headers
|
||||
|
||||
async def close(self) -> None:
|
||||
"""End/Notify close for sub protocol."""
|
||||
if self._dsa is not None:
|
||||
if self._police_officer is not None:
|
||||
async with self._police_officer.borrow(self._response):
|
||||
if self._remote_shutdown is False:
|
||||
try:
|
||||
data_to_send: bytes = self._protocol.send(
|
||||
CloseConnection(0)
|
||||
)
|
||||
except WebSocketProtocolError:
|
||||
pass
|
||||
else:
|
||||
async with self._write_error_catcher():
|
||||
await self._dsa.sendall(data_to_send)
|
||||
await self._dsa.close()
|
||||
self._dsa = None
|
||||
else:
|
||||
self._dsa = None
|
||||
if self._response is not None:
|
||||
if self._police_officer is not None:
|
||||
self._police_officer.forget(self._response)
|
||||
else:
|
||||
await self._response.close()
|
||||
self._response = None
|
||||
|
||||
self._police_officer = None
|
||||
|
||||
async def next_payload(self) -> str | bytes | None:
|
||||
"""Unpack the next received message/payload from remote."""
|
||||
if self._dsa is None or self._response is None or self._police_officer is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
async with self._police_officer.borrow(self._response):
|
||||
for event in self._protocol.events():
|
||||
if isinstance(event, TextMessage):
|
||||
return event.data
|
||||
elif isinstance(event, BytesMessage):
|
||||
return event.data
|
||||
elif isinstance(event, CloseConnection):
|
||||
self._remote_shutdown = True
|
||||
await self.close()
|
||||
return None
|
||||
elif isinstance(event, Ping):
|
||||
try:
|
||||
data_to_send: bytes = self._protocol.send(event.response())
|
||||
except WebSocketProtocolError as e:
|
||||
await self.close()
|
||||
raise ProtocolError from e
|
||||
|
||||
async with self._write_error_catcher():
|
||||
await self._dsa.sendall(data_to_send)
|
||||
|
||||
while True:
|
||||
async with self._read_error_catcher():
|
||||
data, eot, _ = await self._dsa.recv_extended(None)
|
||||
|
||||
try:
|
||||
self._protocol.receive_data(data)
|
||||
except WebSocketProtocolError as e:
|
||||
raise ProtocolError from e
|
||||
|
||||
for event in self._protocol.events():
|
||||
if isinstance(event, TextMessage):
|
||||
return event.data
|
||||
elif isinstance(event, BytesMessage):
|
||||
return event.data
|
||||
elif isinstance(event, CloseConnection):
|
||||
self._remote_shutdown = True
|
||||
await self.close()
|
||||
return None
|
||||
elif isinstance(event, Ping):
|
||||
data_to_send = self._protocol.send(event.response())
|
||||
async with self._write_error_catcher():
|
||||
await self._dsa.sendall(data_to_send)
|
||||
elif isinstance(event, Pong):
|
||||
continue
|
||||
|
||||
async def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Dispatch a buffer to remote."""
|
||||
if self._dsa is None or self._response is None or self._police_officer is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
async with self._police_officer.borrow(self._response):
|
||||
try:
|
||||
if isinstance(buf, str):
|
||||
data_to_send: bytes = self._protocol.send(TextMessage(buf))
|
||||
else:
|
||||
data_to_send = self._protocol.send(BytesMessage(buf))
|
||||
except WebSocketProtocolError as e:
|
||||
raise ProtocolError from e
|
||||
|
||||
async with self._write_error_catcher():
|
||||
await self._dsa.sendall(data_to_send)
|
||||
|
||||
async def ping(self) -> None:
|
||||
if self._dsa is None or self._response is None or self._police_officer is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
async with self._police_officer.borrow(self._response):
|
||||
try:
|
||||
data_to_send: bytes = self._protocol.send(Ping())
|
||||
except WebSocketProtocolError as e:
|
||||
raise ProtocolError from e
|
||||
|
||||
async with self._write_error_catcher():
|
||||
await self._dsa.sendall(data_to_send)
|
||||
|
||||
@staticmethod
|
||||
def supported_schemes() -> set[str]:
|
||||
return {"ws", "wss"}
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"ws": "http", "wss": "https"}[scheme]
|
||||
|
||||
|
||||
class AsyncWebSocketExtensionFromMultiplexedHTTP(AsyncWebSocketExtensionFromHTTP):
|
||||
"""
|
||||
Plugin that support doing WebSocket over HTTP 2 and 3.
|
||||
This implement RFC8441. Beware that this isn't actually supported by much server around internet.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "rfc8441"
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
|
||||
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from abc import ABCMeta
|
||||
from contextlib import contextmanager
|
||||
from socket import timeout as SocketTimeout
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...backend import HttpVersion
|
||||
from ...backend._base import DirectStreamAccess
|
||||
from ...response import HTTPResponse
|
||||
from ...util.traffic_police import TrafficPolice
|
||||
|
||||
from ...exceptions import (
|
||||
BaseSSLError,
|
||||
ProtocolError,
|
||||
ReadTimeoutError,
|
||||
SSLError,
|
||||
MustRedialError,
|
||||
)
|
||||
|
||||
|
||||
class ExtensionFromHTTP(metaclass=ABCMeta):
|
||||
"""Represent an extension that can be negotiated just after a "101 Switching Protocol" HTTP response.
|
||||
This will considerably ease downstream integration."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._dsa: DirectStreamAccess | None = None
|
||||
self._response: HTTPResponse | None = None
|
||||
self._police_officer: TrafficPolice | None = None # type: ignore[type-arg]
|
||||
|
||||
@contextmanager
|
||||
def _read_error_catcher(self) -> typing.Generator[None, None, None]:
|
||||
"""
|
||||
Catch low-level python exceptions, instead re-raising urllib3
|
||||
variants, so that low-level exceptions are not leaked in the
|
||||
high-level api.
|
||||
|
||||
On unrecoverable issues, release the connection back to the pool.
|
||||
"""
|
||||
clean_exit = False
|
||||
|
||||
try:
|
||||
try:
|
||||
yield
|
||||
|
||||
except SocketTimeout as e:
|
||||
clean_exit = True
|
||||
pool = (
|
||||
self._response._pool
|
||||
if self._response and hasattr(self._response, "_pool")
|
||||
else None
|
||||
)
|
||||
raise ReadTimeoutError(pool, None, "Read timed out.") from e # type: ignore[arg-type]
|
||||
|
||||
except BaseSSLError as e:
|
||||
# FIXME: Is there a better way to differentiate between SSLErrors?
|
||||
if "read operation timed out" not in str(e):
|
||||
# SSL errors related to framing/MAC get wrapped and reraised here
|
||||
raise SSLError(e) from e
|
||||
clean_exit = True # ws algorithms based on timeouts can expect this without being harmful!
|
||||
pool = (
|
||||
self._response._pool
|
||||
if self._response and hasattr(self._response, "_pool")
|
||||
else None
|
||||
)
|
||||
raise ReadTimeoutError(pool, None, "Read timed out.") from e # type: ignore[arg-type]
|
||||
|
||||
except (OSError, MustRedialError) as e:
|
||||
# This includes IncompleteRead.
|
||||
raise ProtocolError(f"Connection broken: {e!r}", e) from e
|
||||
|
||||
# If no exception is thrown, we should avoid cleaning up
|
||||
# unnecessarily.
|
||||
clean_exit = True
|
||||
finally:
|
||||
# If we didn't terminate cleanly, we need to throw away our
|
||||
# connection.
|
||||
if not clean_exit:
|
||||
# The response may not be closed but we're not going to use it
|
||||
# anymore so close it now to ensure that the connection is
|
||||
# released back to the pool.
|
||||
if self._response:
|
||||
self.close()
|
||||
|
||||
@contextmanager
|
||||
def _write_error_catcher(self) -> typing.Generator[None, None, None]:
|
||||
"""
|
||||
Catch low-level python exceptions, instead re-raising urllib3
|
||||
variants, so that low-level exceptions are not leaked in the
|
||||
high-level api.
|
||||
|
||||
On unrecoverable issues, release the connection back to the pool.
|
||||
"""
|
||||
clean_exit = False
|
||||
|
||||
try:
|
||||
try:
|
||||
yield
|
||||
|
||||
except SocketTimeout as e:
|
||||
pool = (
|
||||
self._response._pool
|
||||
if self._response and hasattr(self._response, "_pool")
|
||||
else None
|
||||
)
|
||||
raise ReadTimeoutError(pool, None, "Read timed out.") from e # type: ignore[arg-type]
|
||||
|
||||
except BaseSSLError as e:
|
||||
raise SSLError(e) from e
|
||||
|
||||
except OSError as e:
|
||||
# This includes IncompleteRead.
|
||||
raise ProtocolError(f"Connection broken: {e!r}", e) from e
|
||||
|
||||
# If no exception is thrown, we should avoid cleaning up
|
||||
# unnecessarily.
|
||||
clean_exit = True
|
||||
finally:
|
||||
# If we didn't terminate cleanly, we need to throw away our
|
||||
# connection.
|
||||
if not clean_exit:
|
||||
# The response may not be closed but we're not going to use it
|
||||
# anymore so close it now to ensure that the connection is
|
||||
# released back to the pool.
|
||||
if self._response:
|
||||
self.close()
|
||||
|
||||
@property
|
||||
def urlopen_kwargs(self) -> dict[str, typing.Any]:
|
||||
"""Return prerequisites. Must be passed as additional parameters to urlopen."""
|
||||
return {}
|
||||
|
||||
def start(self, response: HTTPResponse) -> None:
|
||||
"""The HTTP server gave us the go-to start negotiating another protocol."""
|
||||
if response._fp is None or not hasattr(response._fp, "_dsa"):
|
||||
raise RuntimeError(
|
||||
"Attempt to start an HTTP extension without direct I/O access to the stream"
|
||||
)
|
||||
|
||||
self._dsa = response._fp._dsa
|
||||
self._police_officer = response._police_officer
|
||||
self._response = response
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._dsa is None
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
"""Hint about supported parent SVN for this extension."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def supported_schemes() -> set[str]:
|
||||
"""Recognized schemes for the extension."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
"""Convert the extension scheme to a known http scheme (either http or https)"""
|
||||
raise NotImplementedError
|
||||
|
||||
def headers(self, http_version: HttpVersion) -> dict[str, str]:
|
||||
"""Specific HTTP headers required (request) before the 101 status response."""
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self) -> None:
|
||||
"""End/Notify close for sub protocol."""
|
||||
raise NotImplementedError
|
||||
|
||||
def next_payload(self) -> str | bytes | None:
|
||||
"""Unpack the next received message/payload from remote. This call does read from the socket.
|
||||
If the method return None, it means that the remote closed the (extension) pipeline.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Dispatch a buffer to remote."""
|
||||
raise NotImplementedError
|
||||
|
||||
def on_payload(self, callback: typing.Callable[[str | bytes | None], None]) -> None:
|
||||
"""Set up a callback that will be invoked automatically once a payload is received.
|
||||
Meaning that you stop calling manually next_payload()."""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...backend import HttpVersion
|
||||
from .protocol import ExtensionFromHTTP
|
||||
|
||||
|
||||
class RawExtensionFromHTTP(ExtensionFromHTTP):
|
||||
"""Raw I/O from given HTTP stream after a 101 Switching Protocol Status."""
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
|
||||
|
||||
def headers(self, http_version: HttpVersion) -> dict[str, str]:
|
||||
"""Specific HTTP headers required (request) before the 101 status response."""
|
||||
return {}
|
||||
|
||||
def close(self) -> None:
|
||||
"""End/Notify close for sub protocol."""
|
||||
if self._dsa is not None:
|
||||
with self._write_error_catcher():
|
||||
self._dsa.close()
|
||||
self._dsa = None
|
||||
if self._response is not None:
|
||||
self._response.close()
|
||||
self._response = None
|
||||
self._police_officer = None
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "raw"
|
||||
|
||||
@staticmethod
|
||||
def supported_schemes() -> set[str]:
|
||||
return set()
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return scheme
|
||||
|
||||
def next_payload(self) -> bytes | None:
|
||||
if self._police_officer is None or self._dsa is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
with self._police_officer.borrow(self._response):
|
||||
with self._read_error_catcher():
|
||||
data, eot, _ = self._dsa.recv_extended(None)
|
||||
return data
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
if self._police_officer is None or self._dsa is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
if isinstance(buf, str):
|
||||
buf = buf.encode()
|
||||
|
||||
with self._police_officer.borrow(self._response):
|
||||
with self._write_error_catcher():
|
||||
self._dsa.sendall(buf)
|
||||
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
from threading import RLock
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...response import HTTPResponse
|
||||
|
||||
from ...backend import HttpVersion
|
||||
from .protocol import ExtensionFromHTTP
|
||||
|
||||
|
||||
class ServerSentEvent:
|
||||
def __init__(
|
||||
self,
|
||||
event: str | None = None,
|
||||
data: str | None = None,
|
||||
id: str | None = None,
|
||||
retry: int | None = None,
|
||||
) -> None:
|
||||
if not event:
|
||||
event = "message"
|
||||
|
||||
if data is None:
|
||||
data = ""
|
||||
|
||||
if id is None:
|
||||
id = ""
|
||||
|
||||
self._event = event
|
||||
self._data = data
|
||||
self._id = id
|
||||
self._retry = retry
|
||||
|
||||
@property
|
||||
def event(self) -> str:
|
||||
return self._event
|
||||
|
||||
@property
|
||||
def data(self) -> str:
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def retry(self) -> int | None:
|
||||
return self._retry
|
||||
|
||||
def json(self) -> typing.Any:
|
||||
return json.loads(self.data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
pieces = [f"event={self.event!r}"]
|
||||
if self.data != "":
|
||||
pieces.append(f"data={self.data!r}")
|
||||
if self.id != "":
|
||||
pieces.append(f"id={self.id!r}")
|
||||
if self.retry is not None:
|
||||
pieces.append(f"retry={self.retry!r}")
|
||||
return f"ServerSentEvent({', '.join(pieces)})"
|
||||
|
||||
|
||||
class ServerSideEventExtensionFromHTTP(ExtensionFromHTTP):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._last_event_id: str | None = None
|
||||
self._buffer: str = ""
|
||||
self._lock = RLock()
|
||||
self._stream: typing.Generator[bytes, None, None] | None = None
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "native"
|
||||
|
||||
@property
|
||||
def urlopen_kwargs(self) -> dict[str, typing.Any]:
|
||||
return {"preload_content": False}
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._stream is None
|
||||
|
||||
def close(self) -> None:
|
||||
if self._stream is not None and self._response is not None:
|
||||
self._stream.close()
|
||||
if (
|
||||
self._response._fp is not None
|
||||
and self._police_officer is not None
|
||||
and hasattr(self._response._fp, "abort")
|
||||
):
|
||||
with self._police_officer.borrow(self._response):
|
||||
self._response._fp.abort()
|
||||
self._stream = None
|
||||
self._response = None
|
||||
self._police_officer = None
|
||||
|
||||
def start(self, response: HTTPResponse) -> None:
|
||||
super().start(response)
|
||||
|
||||
self._stream = response.stream(-1, decode_content=True)
|
||||
|
||||
def headers(self, http_version: HttpVersion) -> dict[str, str]:
|
||||
return {"accept": "text/event-stream", "cache-control": "no-store"}
|
||||
|
||||
@typing.overload
|
||||
def next_payload(self, *, raw: typing.Literal[True] = True) -> str | None: ...
|
||||
|
||||
@typing.overload
|
||||
def next_payload(
|
||||
self, *, raw: typing.Literal[False] = False
|
||||
) -> ServerSentEvent | None: ...
|
||||
|
||||
def next_payload(self, *, raw: bool = False) -> ServerSentEvent | str | None:
|
||||
"""Unpack the next received message/payload from remote."""
|
||||
if self._response is None or self._stream is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
with self._lock:
|
||||
try:
|
||||
raw_payload: str = next(self._stream).decode("utf-8")
|
||||
except StopIteration:
|
||||
self._stream = None
|
||||
return None
|
||||
|
||||
if self._buffer:
|
||||
raw_payload = self._buffer + raw_payload
|
||||
self._buffer = ""
|
||||
|
||||
kwargs: dict[str, typing.Any] = {}
|
||||
eot = False
|
||||
|
||||
for line in raw_payload.splitlines():
|
||||
if not line:
|
||||
eot = True
|
||||
break
|
||||
key, _, value = line.partition(":")
|
||||
if key not in {"event", "data", "retry", "id"}:
|
||||
continue
|
||||
if value.startswith(" "):
|
||||
value = value[1:]
|
||||
if key == "id":
|
||||
if "\u0000" in value:
|
||||
continue
|
||||
if key == "retry":
|
||||
try:
|
||||
value = int(value) # type: ignore[assignment]
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
kwargs[key] = value
|
||||
|
||||
if eot is False:
|
||||
self._buffer = raw_payload
|
||||
return self.next_payload(raw=raw) # type: ignore[call-overload,no-any-return]
|
||||
|
||||
if "id" not in kwargs and self._last_event_id is not None:
|
||||
kwargs["id"] = self._last_event_id
|
||||
|
||||
event = ServerSentEvent(**kwargs)
|
||||
|
||||
if event.id:
|
||||
self._last_event_id = event.id
|
||||
|
||||
if raw is True:
|
||||
return raw_payload
|
||||
|
||||
return event
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Dispatch a buffer to remote."""
|
||||
raise NotImplementedError("SSE is only one-way. Sending is forbidden.")
|
||||
|
||||
@staticmethod
|
||||
def supported_schemes() -> set[str]:
|
||||
return {"sse", "psse"}
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"sse": "https", "psse": "http"}[scheme]
|
||||
@@ -0,0 +1,247 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ...response import HTTPResponse
|
||||
|
||||
from wsproto import ConnectionType, WSConnection
|
||||
from wsproto.events import (
|
||||
AcceptConnection,
|
||||
BytesMessage,
|
||||
CloseConnection,
|
||||
Ping,
|
||||
Pong,
|
||||
Request,
|
||||
TextMessage,
|
||||
)
|
||||
from wsproto.extensions import PerMessageDeflate
|
||||
from wsproto.utilities import ProtocolError as WebSocketProtocolError
|
||||
|
||||
from ...backend import HttpVersion
|
||||
from ...exceptions import ProtocolError
|
||||
from .protocol import ExtensionFromHTTP
|
||||
|
||||
|
||||
class WebSocketExtensionFromHTTP(ExtensionFromHTTP):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._protocol = WSConnection(ConnectionType.CLIENT)
|
||||
self._request_headers: dict[str, str] | None = None
|
||||
self._remote_shutdown: bool = False
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
return {HttpVersion.h11}
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "wsproto"
|
||||
|
||||
def start(self, response: HTTPResponse) -> None:
|
||||
super().start(response)
|
||||
|
||||
fake_http_response = b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n"
|
||||
|
||||
fake_http_response += b"Sec-Websocket-Accept: "
|
||||
|
||||
accept_token: str | None = response.headers.get("Sec-Websocket-Accept")
|
||||
|
||||
if accept_token is None:
|
||||
raise ProtocolError(
|
||||
"The WebSocket HTTP extension requires 'Sec-Websocket-Accept' header in the server response but was not present."
|
||||
)
|
||||
|
||||
fake_http_response += accept_token.encode() + b"\r\n"
|
||||
|
||||
if "sec-websocket-extensions" in response.headers:
|
||||
fake_http_response += (
|
||||
b"Sec-Websocket-Extensions: "
|
||||
+ response.headers.get("sec-websocket-extensions").encode() # type: ignore[union-attr]
|
||||
+ b"\r\n"
|
||||
)
|
||||
|
||||
fake_http_response += b"\r\n"
|
||||
|
||||
try:
|
||||
self._protocol.receive_data(fake_http_response)
|
||||
except WebSocketProtocolError as e:
|
||||
raise ProtocolError from e # Defensive: should never happen
|
||||
|
||||
event = next(self._protocol.events())
|
||||
|
||||
if not isinstance(event, AcceptConnection):
|
||||
raise RuntimeError(
|
||||
"The WebSocket state-machine did not pass the handshake phase when expected."
|
||||
)
|
||||
|
||||
def headers(self, http_version: HttpVersion) -> dict[str, str]:
|
||||
"""Specific HTTP headers required (request) before the 101 status response."""
|
||||
if self._request_headers is not None:
|
||||
return self._request_headers
|
||||
|
||||
try:
|
||||
raw_data_to_socket = self._protocol.send(
|
||||
Request(
|
||||
host="example.com", target="/", extensions=(PerMessageDeflate(),)
|
||||
)
|
||||
)
|
||||
except WebSocketProtocolError as e:
|
||||
raise ProtocolError from e # Defensive: should never happen
|
||||
|
||||
raw_headers = raw_data_to_socket.split(b"\r\n")[2:-2]
|
||||
request_headers: dict[str, str] = {}
|
||||
|
||||
for raw_header in raw_headers:
|
||||
k, v = raw_header.decode().split(": ")
|
||||
request_headers[k.lower()] = v
|
||||
|
||||
if http_version != HttpVersion.h11:
|
||||
del request_headers["upgrade"]
|
||||
del request_headers["connection"]
|
||||
request_headers[":protocol"] = "websocket"
|
||||
request_headers[":method"] = "CONNECT"
|
||||
|
||||
self._request_headers = request_headers
|
||||
|
||||
return request_headers
|
||||
|
||||
def close(self) -> None:
|
||||
"""End/Notify close for sub protocol."""
|
||||
if self._dsa is not None:
|
||||
if self._police_officer is not None:
|
||||
with self._police_officer.borrow(self._response):
|
||||
if self._remote_shutdown is False:
|
||||
try:
|
||||
data_to_send: bytes = self._protocol.send(
|
||||
CloseConnection(0)
|
||||
)
|
||||
except WebSocketProtocolError:
|
||||
pass
|
||||
else:
|
||||
with self._write_error_catcher():
|
||||
self._dsa.sendall(data_to_send)
|
||||
self._dsa.close()
|
||||
self._dsa = None
|
||||
else:
|
||||
self._dsa = None
|
||||
if self._response is not None:
|
||||
if self._police_officer is not None:
|
||||
self._police_officer.forget(self._response)
|
||||
else:
|
||||
self._response.close()
|
||||
self._response = None
|
||||
|
||||
self._police_officer = None
|
||||
|
||||
def next_payload(self) -> str | bytes | None:
|
||||
"""Unpack the next received message/payload from remote."""
|
||||
if self._dsa is None or self._response is None or self._police_officer is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
with self._police_officer.borrow(self._response):
|
||||
# we may have pending event to unpack!
|
||||
for event in self._protocol.events():
|
||||
if isinstance(event, TextMessage):
|
||||
return event.data
|
||||
elif isinstance(event, BytesMessage):
|
||||
return event.data
|
||||
elif isinstance(event, CloseConnection):
|
||||
self._remote_shutdown = True
|
||||
self.close()
|
||||
return None
|
||||
elif isinstance(event, Ping):
|
||||
try:
|
||||
data_to_send: bytes = self._protocol.send(event.response())
|
||||
except WebSocketProtocolError as e:
|
||||
self.close()
|
||||
raise ProtocolError from e
|
||||
|
||||
with self._write_error_catcher():
|
||||
self._dsa.sendall(data_to_send)
|
||||
|
||||
while True:
|
||||
with self._read_error_catcher():
|
||||
data, eot, _ = self._dsa.recv_extended(None)
|
||||
|
||||
try:
|
||||
self._protocol.receive_data(data)
|
||||
except WebSocketProtocolError as e:
|
||||
self.close()
|
||||
raise ProtocolError from e
|
||||
|
||||
for event in self._protocol.events():
|
||||
if isinstance(event, TextMessage):
|
||||
return event.data
|
||||
elif isinstance(event, BytesMessage):
|
||||
return event.data
|
||||
elif isinstance(event, CloseConnection):
|
||||
self._remote_shutdown = True
|
||||
self.close()
|
||||
return None
|
||||
elif isinstance(event, Ping):
|
||||
try:
|
||||
data_to_send = self._protocol.send(event.response())
|
||||
except WebSocketProtocolError as e:
|
||||
self.close()
|
||||
raise ProtocolError from e
|
||||
with self._write_error_catcher():
|
||||
self._dsa.sendall(data_to_send)
|
||||
elif isinstance(event, Pong):
|
||||
continue
|
||||
|
||||
def send_payload(self, buf: str | bytes) -> None:
|
||||
"""Dispatch a buffer to remote."""
|
||||
if self._dsa is None or self._response is None or self._police_officer is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
with self._police_officer.borrow(self._response):
|
||||
try:
|
||||
if isinstance(buf, str):
|
||||
data_to_send: bytes = self._protocol.send(TextMessage(buf))
|
||||
else:
|
||||
data_to_send = self._protocol.send(BytesMessage(buf))
|
||||
except WebSocketProtocolError as e:
|
||||
self.close()
|
||||
raise ProtocolError from e
|
||||
|
||||
with self._write_error_catcher():
|
||||
self._dsa.sendall(data_to_send)
|
||||
|
||||
def ping(self) -> None:
|
||||
if self._dsa is None or self._response is None or self._police_officer is None:
|
||||
raise OSError("The HTTP extension is closed or uninitialized")
|
||||
|
||||
with self._police_officer.borrow(self._response):
|
||||
if self._remote_shutdown is False:
|
||||
try:
|
||||
data_to_send: bytes = self._protocol.send(Ping())
|
||||
except WebSocketProtocolError as e:
|
||||
self.close()
|
||||
raise ProtocolError from e
|
||||
|
||||
with self._write_error_catcher():
|
||||
self._dsa.sendall(data_to_send)
|
||||
|
||||
@staticmethod
|
||||
def supported_schemes() -> set[str]:
|
||||
return {"ws", "wss"}
|
||||
|
||||
@staticmethod
|
||||
def scheme_to_http_scheme(scheme: str) -> str:
|
||||
return {"ws": "http", "wss": "https"}[scheme]
|
||||
|
||||
|
||||
class WebSocketExtensionFromMultiplexedHTTP(WebSocketExtensionFromHTTP):
|
||||
"""
|
||||
Plugin that support doing WebSocket over HTTP 2 and 3.
|
||||
This implement RFC8441. Beware that this isn't actually supported by much server around internet.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def implementation() -> str:
|
||||
return "rfc8441" # also known as rfc9220 (http3)
|
||||
|
||||
@staticmethod
|
||||
def supported_svn() -> set[HttpVersion]:
|
||||
return {HttpVersion.h11, HttpVersion.h2, HttpVersion.h3}
|
||||
Reference in New Issue
Block a user