fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099
Jellyfin(8096), OrbStack(8097) 포트 충돌로 변경. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,520 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import platform
|
||||
import socket
|
||||
import typing
|
||||
import warnings
|
||||
|
||||
from ._timeout import timeout
|
||||
from ._gro import open_dgram_connection, DatagramReader, DatagramWriter
|
||||
|
||||
# Alias kept so callers can raise/catch the classic blocking-socket timeout.
StandardTimeoutError = socket.timeout

# Unify the three historical flavors of TimeoutError: on modern interpreters
# they all collapse into the builtin TimeoutError, but older versions expose
# distinct classes that we must catch explicitly.
try:
    from concurrent.futures import TimeoutError as FutureTimeoutError
except ImportError:
    FutureTimeoutError = TimeoutError  # type: ignore[misc]

try:
    AsyncioTimeoutError = asyncio.exceptions.TimeoutError
except AttributeError:
    AsyncioTimeoutError = TimeoutError  # type: ignore[misc]
|
||||
|
||||
# Typing-only imports: never evaluated at runtime, so ssl stays lazily loaded.
if typing.TYPE_CHECKING:
    import ssl

    from typing_extensions import Literal

    from ..._typing import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT
|
||||
|
||||
|
||||
def _can_shutdown_and_close_selector_loop_bug() -> bool:
|
||||
import platform
|
||||
|
||||
if platform.system() == "Windows" and platform.python_version_tuple()[:2] == (
|
||||
"3",
|
||||
"7",
|
||||
):
|
||||
return int(platform.python_version_tuple()[-1]) >= 17
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Windows + asyncio bug where doing our shutdown procedure induces a crash
# in SelectorLoop:
# File "C:\hostedtoolcache\windows\Python\3.7.9\x64\lib\selectors.py", line 314, in _select
#     r, w, x = select.select(r, w, w, timeout)
# [WinError 10038] An operation was attempted on something that is not a socket
# True when the affected interpreter is detected; close() consults this flag.
_CPYTHON_SELECTOR_CLOSE_BUG_EXIST = _can_shutdown_and_close_selector_loop_bug() is False
|
||||
|
||||
|
||||
class AsyncSocket:
    """
    This class is brought to add a level of abstraction to an asyncio transport (reader, or writer)
    We don't want to have two distinct code (async/sync) but rather a unified and easily verifiable
    code base.

    'ssa' stands for Simplified - Socket - Asynchronous.
    """

    def __init__(
        self,
        family: socket.AddressFamily = socket.AF_INET,
        type: socket.SocketKind = socket.SOCK_STREAM,
        proto: int = -1,
        fileno: int | None = None,
    ) -> None:
        """Build the wrapper around a freshly created non-blocking socket."""
        self.family: socket.AddressFamily = family
        self.type: socket.SocketKind = type
        self.proto: int = proto
        self._fileno: int | None = fileno

        # True between connect() being invoked and close().
        self._connect_called: bool = False
        # Set once the transport is ready for I/O (post connect / TLS upgrade).
        self._established: asyncio.Event = asyncio.Event()

        # we do that everytime to forward properly options / advanced settings
        self._sock: socket.socket = socket.socket(
            family=self.family, type=self.type, proto=self.proto, fileno=fileno
        )
        # set nonblocking / or cause the loop to block with dgram socket...
        self._sock.settimeout(0)

        self._writer: asyncio.StreamWriter | DatagramWriter | None = None
        self._reader: asyncio.StreamReader | DatagramReader | None = None

        # Serialize I/O: at most one in-flight send and one in-flight recv.
        self._writer_semaphore: asyncio.Semaphore = asyncio.Semaphore()
        self._reader_semaphore: asyncio.Semaphore = asyncio.Semaphore()

        self._addr: tuple[str, int] | tuple[str, int, int, int] | None = None

        # Mimics socket.settimeout(); applied to connect() and recv().
        self._external_timeout: float | int | None = None
        self._tls_in_tls = False

    def fileno(self) -> int:
        """Return the file descriptor (the one given at init takes precedence)."""
        return self._fileno if self._fileno is not None else self._sock.fileno()

    async def wait_for_close(self) -> None:
        """Best-effort wait (max 1s) until the underlying writer is fully closed."""
        if self._connect_called:
            # Still flagged as connected: close() was not run first, nothing to await.
            return

        if self._writer is None:
            return

        try:
            # report made in https://github.com/jawah/niquests/issues/184
            # made us believe that sometime ssl_transport is freed before
            # getting there. So we could end up there with a half broken
            # writer state. The original user was using Windows at the time.
            is_ssl = self._writer.get_extra_info("ssl_object") is not None
        except AttributeError:
            is_ssl = False

        if is_ssl:
            # Give the connection a chance to write any data in the buffer,
            # and then forcibly tear down the SSL connection.
            await asyncio.sleep(0)
            self._writer.transport.abort()

        try:
            # wait_closed can hang indefinitely!
            # on Python 3.8 and 3.9
            # there's some case where Python want an explicit EOT
            # (spoiler: it was a CPython bug) fixed in recent interpreters.
            # to circumvent this and still have a proper close
            # we enforce a maximum delay (1000ms).
            async with timeout(1):
                await self._writer.wait_closed()
        except TimeoutError:
            pass

    def close(self) -> None:
        """Tear down writer and socket, dodging known asyncio/uvloop quirks."""
        if self._writer is not None:
            self._writer.close()

        edge_case_close_bug_exist = _CPYTHON_SELECTOR_CLOSE_BUG_EXIST

        # Windows + asyncio + asyncio.SelectorEventLoop limits us on how far
        # we can safely shutdown the socket.
        if not edge_case_close_bug_exist and platform.system() == "Windows":
            if hasattr(asyncio, "SelectorEventLoop") and isinstance(
                asyncio.get_running_loop(), asyncio.SelectorEventLoop
            ):
                edge_case_close_bug_exist = True

        try:
            # see https://github.com/MagicStack/uvloop/issues/241
            # and https://github.com/jawah/niquests/issues/166
            # probably not just uvloop.
            uvloop_edge_case_bug = False

            # keep track of our clean exit procedure
            shutdown_called = False
            close_called = False

            if hasattr(self._sock, "shutdown"):
                try:
                    self._sock.shutdown(socket.SHUT_RD)
                    shutdown_called = True
                except TypeError:
                    uvloop_edge_case_bug = True
                    # uvloop don't support shutdown! and sometime does not support close()...
                    # see https://github.com/jawah/niquests/issues/166 for ctx.
                    try:
                        self._sock.close()
                        close_called = True
                    except TypeError:
                        # last chance of releasing properly the underlying fd!
                        try:
                            direct_sock = socket.socket(fileno=self._sock.fileno())
                        except (OSError, ValueError):
                            pass
                        else:
                            try:
                                direct_sock.shutdown(socket.SHUT_RD)
                                shutdown_called = True
                            except OSError:
                                warnings.warn(
                                    (
                                        "urllib3-future is unable to properly close your async socket. "
                                        "This mean that you are probably using an asyncio implementation like uvloop "
                                        "that does not support shutdown() or/and close() on the socket transport. "
                                        "This will lead to unclosed socket (fd)."
                                    ),
                                    ResourceWarning,
                                )
                            finally:
                                direct_sock.detach()
            # we have to force call close() on our sock object (even after shutdown).
            # or we'll get a resource warning for sure!
            if isinstance(self._sock, socket.socket) and hasattr(self._sock, "close"):
                if not uvloop_edge_case_bug and not edge_case_close_bug_exist:
                    try:
                        self._sock.close()
                        close_called = True
                    except (OSError, TypeError):
                        pass

            if not close_called or not shutdown_called:
                # this branch detect whether we have an asyncio.TransportSocket instead of socket.socket.
                if hasattr(self._sock, "_sock") and not edge_case_close_bug_exist:
                    try:
                        self._sock._sock.close()
                    except (AttributeError, OSError, TypeError):
                        pass

        except (
            OSError
        ):  # branch where we failed to connect and still try to release resource
            if isinstance(self._sock, socket.socket):
                try:
                    self._sock.close()  # don't call close on asyncio.TransportSocket
                except (OSError, TypeError, AttributeError):
                    pass
            elif hasattr(self._sock, "_sock") and not edge_case_close_bug_exist:
                try:
                    self._sock._sock.detach()
                except (AttributeError, OSError, TypeError):
                    pass

        self._connect_called = False
        self._established.clear()

    async def wait_for_readiness(self) -> None:
        """Block until the transport is established and usable."""
        await self._established.wait()

    def setsockopt(self, __level: int, __optname: int, __value: int | bytes) -> None:
        """Proxy to the underlying socket's setsockopt."""
        self._sock.setsockopt(__level, __optname, __value)

    @typing.overload
    def getsockopt(self, __level: int, __optname: int) -> int: ...

    @typing.overload
    def getsockopt(self, __level: int, __optname: int, buflen: int) -> bytes: ...

    def getsockopt(
        self, __level: int, __optname: int, buflen: int | None = None
    ) -> int | bytes:
        """Proxy to the underlying socket's getsockopt (int or bytes form)."""
        if buflen is None:
            return self._sock.getsockopt(__level, __optname)
        return self._sock.getsockopt(__level, __optname, buflen)

    def should_connect(self) -> bool:
        """True when connect() has not been called yet (or after close())."""
        return self._connect_called is False

    async def connect(self, addr: tuple[str, int] | tuple[str, int, int, int]) -> None:
        """Connect the socket and attach the appropriate reader/writer pair.

        Raises OSError on double-connect, socket.timeout when the external
        timeout elapses, and ConnectionError on loop/FD allocation races.
        """
        if self._connect_called:
            raise OSError(
                "attempted to connect twice on a already established connection"
            )

        self._connect_called = True

        # there's a particularity on Windows
        # we must not forward non-IP in addr due to
        # a limitation in the network bridge used in asyncio
        if platform.system() == "Windows":
            from ..resolver.utils import is_ipv4, is_ipv6

            host, port = addr[:2]

            if not is_ipv4(host) and not is_ipv6(host):
                res = await asyncio.get_running_loop().getaddrinfo(
                    host,
                    port,
                    family=self.family,
                    type=self.type,
                )

                if not res:
                    raise socket.gaierror(f"unable to resolve hostname {host}")

                addr = res[0][-1]

        if self._external_timeout is not None:
            try:
                async with timeout(self._external_timeout):
                    await asyncio.get_running_loop().sock_connect(self._sock, addr)
            except (FutureTimeoutError, AsyncioTimeoutError, TimeoutError) as e:
                # Roll back so the caller may retry connect() on this instance.
                self._connect_called = False
                raise StandardTimeoutError from e
            except RuntimeError:
                raise ConnectionError(
                    "Likely FD Kernel/Loop Racing Allocation Error. You should retry."
                )
        else:
            try:
                await asyncio.get_running_loop().sock_connect(self._sock, addr)
            except RuntimeError:  # Defensive: CPython might raise RuntimeError if there is a FD allocation error.
                raise ConnectionError(
                    "Likely FD Kernel/Loop Racing Allocation Error. You should retry."
                )

        if self.type == socket.SOCK_STREAM or self.type == -1:  # type: ignore[comparison-overlap]
            self._reader, self._writer = await asyncio.open_connection(sock=self._sock)
        elif self.type == socket.SOCK_DGRAM:
            self._reader, self._writer = await open_dgram_connection(sock=self._sock)

        # can become an asyncio.TransportSocket
        assert self._writer is not None
        self._sock = self._writer.get_extra_info("socket", self._sock)

        self._addr = addr
        self._established.set()

    async def wrap_socket(
        self,
        ctx: ssl.SSLContext,
        *,
        server_hostname: str | None = None,
        ssl_handshake_timeout: float | None = None,
    ) -> SSLAsyncSocket:
        """Upgrade the established stream to TLS and morph into SSLAsyncSocket."""
        await self._established.wait()
        self._established.clear()

        # only if Python <= 3.10
        try:
            setattr(
                asyncio.sslproto._SSLProtocolTransport,  # type: ignore[attr-defined]
                "_start_tls_compatible",
                True,
            )
        except AttributeError:
            pass

        if self.type == socket.SOCK_STREAM:
            assert self._writer is not None
            assert isinstance(self._writer, asyncio.StreamWriter)

            # bellow is hard to maintain. Starting with 3.11+, it is useless.
            protocol = self._writer._protocol  # type: ignore[attr-defined]
            await self._writer.drain()

            new_transport = await self._writer._loop.start_tls(  # type: ignore[attr-defined]
                self._writer._transport,  # type: ignore[attr-defined]
                protocol,
                ctx,
                server_side=False,
                server_hostname=server_hostname,
                ssl_handshake_timeout=ssl_handshake_timeout,
            )

            self._writer._transport = new_transport  # type: ignore[attr-defined]

            transport = self._writer.transport
            protocol._stream_writer = self._writer
            protocol._transport = transport
            protocol._over_ssl = transport.get_extra_info("sslcontext") is not None

            self._tls_ctx = ctx
        else:
            raise RuntimeError("Unsupported socket type")

        self._established.set()
        self.__class__ = SSLAsyncSocket

        return self  # type: ignore[return-value]

    async def recv(self, size: int = -1) -> bytes | list[bytes]:
        """Receive data from the socket.

        Returns ``bytes`` for a single datagram (or stream chunk), or
        ``list[bytes]`` when GRO / batch-receive delivered multiple
        coalesced datagrams in one syscall. The caller can then feed
        all segments to the QUIC state-machine in a tight loop before
        probing, avoiding per-datagram overhead."""
        if size == -1:
            size = 65536
        assert self._reader is not None
        await self._established.wait()
        await self._reader_semaphore.acquire()

        try:
            if self._external_timeout is not None:
                try:
                    async with timeout(self._external_timeout):
                        return await self._reader.read(n=size)
                except (FutureTimeoutError, AsyncioTimeoutError, TimeoutError) as e:
                    # Fix: do NOT release the semaphore here — the ``finally``
                    # below already does. asyncio.Semaphore is unbounded, so a
                    # double release would silently allow two concurrent
                    # readers for the rest of this socket's lifetime.
                    raise StandardTimeoutError from e
                except OSError as e:  # Defensive: treat any OSError as ConnReset!
                    raise ConnectionResetError() from e
            return await self._reader.read(n=size)
        finally:
            self._reader_semaphore.release()

    async def read_exact(self, size: int = -1) -> bytes | list[bytes]:
        """Just an alias for recv(), it is needed due to our custom AsyncSocks override."""
        return await self.recv(size=size)

    async def read(self) -> bytes | list[bytes]:
        """Just an alias for recv(), it is needed due to our custom AsyncSocks override."""
        return await self.recv()

    async def sendall(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Write *data* and drain, holding the writer semaphore throughout."""
        assert self._writer is not None
        await self._established.wait()
        await self._writer_semaphore.acquire()
        try:
            self._writer.write(data)  # type: ignore[arg-type]
            await self._writer.drain()
        finally:
            # NOTE: the former ``except Exception: raise`` clause was a no-op
            # and has been removed; exceptions still propagate unchanged.
            self._writer_semaphore.release()

    async def write_all(
        self, data: bytes | bytearray | memoryview | list[bytes]
    ) -> None:
        """Just an alias for sendall(), it is needed due to our custom AsyncSocks override."""
        await self.sendall(data)

    async def send(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Alias for sendall() mirroring the blocking-socket API."""
        await self.sendall(data)

    def settimeout(self, __value: float | None = None) -> None:
        """Set the timeout applied to connect()/recv(); None disables it."""
        self._external_timeout = __value

    def gettimeout(self) -> float | None:
        """Return the timeout previously set via settimeout()."""
        return self._external_timeout

    def getpeername(self) -> tuple[str, int]:
        """Proxy to the underlying socket's getpeername."""
        return self._sock.getpeername()  # type: ignore[no-any-return]

    def bind(self, addr: tuple[str, int]) -> None:
        """Proxy to the underlying socket's bind."""
        self._sock.bind(addr)
|
||||
|
||||
|
||||
class SSLAsyncSocket(AsyncSocket):
    """An ``AsyncSocket`` whose transport was upgraded to TLS.

    Instances are normally produced by ``AsyncSocket.wrap_socket`` which
    swaps ``__class__`` once the handshake completes. The extra surface
    mirrors the pieces of ``ssl.SSLSocket`` we rely upon, all backed by
    the ``ssl_object`` extracted from the underlying transport.
    """

    _tls_ctx: ssl.SSLContext
    _tls_in_tls: bool

    @typing.overload
    def getpeercert(
        self, binary_form: Literal[False] = ...
    ) -> _TYPE_PEER_CERT_RET_DICT | None: ...

    @typing.overload
    def getpeercert(self, binary_form: Literal[True]) -> bytes | None: ...

    def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET:
        """Return the peer certificate; DER bytes when *binary_form* is true."""
        return self.sslobj.getpeercert(binary_form=binary_form)  # type: ignore[return-value]

    def selected_alpn_protocol(self) -> str | None:
        """Return the ALPN protocol negotiated during the handshake, if any."""
        return self.sslobj.selected_alpn_protocol()

    @property
    def sslobj(self) -> ssl.SSLSocket | ssl.SSLObject:
        """Fetch the live SSL object attached to our transport."""
        writer = self._writer

        if writer is not None:
            candidate: ssl.SSLSocket | ssl.SSLObject | None = writer.get_extra_info(
                "ssl_object"
            )

            if candidate is not None:
                return candidate

        raise RuntimeError(
            '"ssl_object" could not be extracted from this SslAsyncSock instance'
        )

    def version(self) -> str | None:
        """TLS protocol version string in use."""
        return self.sslobj.version()

    @property
    def context(self) -> ssl.SSLContext:
        """The ``ssl.SSLContext`` backing this TLS session."""
        return self.sslobj.context

    @property
    def _sslobj(self) -> ssl.SSLSocket | ssl.SSLObject:
        # Alias kept for code paths expecting the CPython-private name.
        return self.sslobj

    def cipher(self) -> tuple[str, str, int] | None:
        """Return the active cipher description triple, if established."""
        return self.sslobj.cipher()

    async def wrap_socket(
        self,
        ctx: ssl.SSLContext,
        *,
        server_hostname: str | None = None,
        ssl_handshake_timeout: float | None = None,
    ) -> SSLAsyncSocket:
        """Start a TLS-in-TLS session on top of the current TLS transport."""
        self._tls_in_tls = True

        return await super().wrap_socket(
            ctx,
            server_hostname=server_hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
        )
|
||||
|
||||
|
||||
def _has_complete_support_dgram() -> bool:
|
||||
"""A bug exist in PyPy asyncio implementation that prevent us to use a DGRAM socket.
|
||||
This piece of code inform us, potentially, if PyPy has fixed the winapi implementation.
|
||||
See https://github.com/pypy/pypy/issues/4008 and https://github.com/jawah/niquests/pull/87
|
||||
|
||||
The stacktrace look as follows:
|
||||
File "C:\\hostedtoolcache\\windows\\PyPy\3.10.13\x86\\Lib\asyncio\\windows_events.py", line 594, in connect
|
||||
_overlapped.WSAConnect(conn.fileno(), address)
|
||||
AttributeError: module '_overlapped' has no attribute 'WSAConnect'
|
||||
"""
|
||||
import platform
|
||||
|
||||
if platform.system() == "Windows" and platform.python_implementation() == "PyPy":
|
||||
try:
|
||||
import _overlapped # type: ignore[import-not-found]
|
||||
except ImportError: # Defensive:
|
||||
return False
|
||||
|
||||
if hasattr(_overlapped, "WSAConnect"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Public API of this module.
__all__ = (
    "AsyncSocket",
    "SSLAsyncSocket",
    "_has_complete_support_dgram",
)
|
||||
@@ -0,0 +1,640 @@
|
||||
"""
|
||||
High-performance asyncio DatagramTransport with Linux-specific UDP
|
||||
receive/send coalescing:
|
||||
|
||||
- GRO (receive): ``setsockopt(SOL_UDP, UDP_GRO)`` + ``recvmsg`` cmsg
|
||||
- GSO (send): ``sendmsg`` with ``UDP_SEGMENT`` cmsg
|
||||
|
||||
All other platforms fall back to the standard asyncio DatagramTransport.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import socket
|
||||
import struct
|
||||
from collections import deque
|
||||
from typing import Any, Callable
|
||||
|
||||
from ..._constant import UDP_LINUX_GRO, UDP_LINUX_SEGMENT
|
||||
|
||||
# Native-endian unsigned 16-bit: the wire layout of the UDP_GRO / UDP_SEGMENT
# cmsg payload (the kernel's segment-size field).
_UINT16 = struct.Struct("=H")

# Receive buffer large enough to hold a maximal GRO super-datagram.
_DEFAULT_GRO_BUF = 65535

# Flow control watermarks for the custom write queue
_HIGH_WATERMARK = 64 * 1024
_LOW_WATERMARK = 16 * 1024

# GSO kernel limit: max segments per sendmsg call
_GSO_MAX_SEGMENTS = 64
|
||||
|
||||
|
||||
def _sock_has_gro(sock: socket.socket) -> bool:
    """Check if GRO is enabled on *sock* (caller must have set it)."""
    try:
        value = sock.getsockopt(socket.SOL_UDP, UDP_LINUX_GRO)
    except OSError:
        # Kernel or platform does not recognize the option at all.
        return False
    return value == 1
|
||||
|
||||
|
||||
def _sock_has_gso(sock: socket.socket) -> bool:
    """Probe whether the kernel accepts the UDP segmentation-offload option."""
    try:
        sock.getsockopt(socket.SOL_UDP, UDP_LINUX_SEGMENT)
    except OSError:
        return False
    else:
        # Being able to read the option back is our support signal.
        return True
|
||||
|
||||
|
||||
def _split_gro_buffer(buf: bytes, segment_size: int) -> list[bytes]:
|
||||
if segment_size <= 0 or len(buf) <= segment_size:
|
||||
return [buf]
|
||||
segments = []
|
||||
mv = memoryview(buf)
|
||||
for offset in range(0, len(buf), segment_size):
|
||||
segments.append(bytes(mv[offset : offset + segment_size]))
|
||||
return segments
|
||||
|
||||
|
||||
def _group_by_segment_size(datagrams: list[bytes]) -> list[tuple[int, list[bytes]]]:
|
||||
"""Group consecutive same-size datagrams for Linux UDP GSO.
|
||||
|
||||
GSO requires all segments to be the same size (except the last,
|
||||
which may be shorter). Max 64 segments per ``sendmsg`` call."""
|
||||
if not datagrams:
|
||||
return []
|
||||
groups: list[tuple[int, list[bytes]]] = []
|
||||
current_size = len(datagrams[0])
|
||||
current_group: list[bytes] = [datagrams[0]]
|
||||
for dgram in datagrams[1:]:
|
||||
if len(dgram) == current_size and len(current_group) < _GSO_MAX_SEGMENTS:
|
||||
current_group.append(dgram)
|
||||
else:
|
||||
groups.append((current_size, current_group))
|
||||
current_size = len(dgram)
|
||||
current_group = [dgram]
|
||||
groups.append((current_size, current_group))
|
||||
return groups
|
||||
|
||||
|
||||
def sync_recv_gro(
    sock: socket.socket, bufsize: int, gro_segment_size: int = 1280
) -> bytes | list[bytes]:
    """Blocking recvmsg with GRO cmsg parsing. Returns bytes or list[bytes]."""
    control_space = socket.CMSG_SPACE(_UINT16.size)

    payload, control, _flags, _addr = sock.recvmsg(bufsize, control_space)

    if not payload:
        # Zero-length read: surface it as empty bytes.
        return b""

    seg_size = gro_segment_size

    for level, ctype, cdata in control:
        if level == socket.SOL_UDP and ctype == UDP_LINUX_GRO:
            # The kernel advertises the coalesced segment size as a uint16.
            (seg_size,) = _UINT16.unpack(cdata[:2])
            break

    if len(payload) <= seg_size:
        # Nothing was coalesced — a single plain datagram.
        return payload

    return _split_gro_buffer(payload, seg_size)
|
||||
|
||||
|
||||
def sync_sendmsg_gso(sock: socket.socket, datagrams: list[bytes]) -> None:
    """Batch-send datagrams using GSO. Falls back to individual sends."""
    for seg_size, batch in _group_by_segment_size(datagrams):
        if len(batch) == 1:
            # A lone datagram gains nothing from GSO — send it plainly.
            sock.sendall(batch[0])
            continue

        # One sendmsg carries the whole batch; the cmsg tells the kernel
        # where to split it back into individual datagrams.
        cmsg = (socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(seg_size))
        sock.sendmsg([b"".join(batch)], [cmsg])
|
||||
|
||||
|
||||
class _OptimizedDatagramTransport(asyncio.DatagramTransport):
    """Datagram transport using Linux GRO/GSO to batch UDP datagrams.

    Differences from the stock asyncio transport: ``recvmsg`` with a GRO
    cmsg may deliver several coalesced datagrams per syscall, and
    ``sendto_many`` can push several same-size datagrams in one
    ``sendmsg`` via GSO. A hand-rolled write queue with high/low
    watermarks provides flow control.
    """

    # __slots__: this transport is allocated per connection; avoid the
    # per-instance __dict__.
    __slots__ = (
        "_loop",
        "_sock",
        "_protocol",
        "_address",
        "_gro_enabled",
        "_gso_enabled",
        "_gro_segment_size",
        "_recv_buf_size",
        "_closing",
        "_closed_fut",
        "_extra",
        "_paused",
        "_write_ready",
        "_send_queue",
        "_buffer_size",
        "_protocol_paused",
    )

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        sock: socket.socket,
        protocol: asyncio.DatagramProtocol,
        address: tuple[str, int] | None,
        gro_enabled: bool,
        gso_enabled: bool,
        gro_segment_size: int,
    ) -> None:
        """Wire the transport state; reading starts only in ``_start()``."""
        super().__init__()
        self._loop = loop
        self._sock = sock
        self._protocol = protocol
        # None when the socket is unconnected (sendto must carry an address).
        self._address = address
        self._gro_enabled = gro_enabled
        self._gso_enabled = gso_enabled
        self._gro_segment_size = gro_segment_size
        self._closing = False
        self._closed_fut: asyncio.Future[None] = loop.create_future()
        self._paused = False
        # False while the socket write buffer is full (writer callback armed).
        self._write_ready = True

        # Write buffer state
        self._send_queue: deque[tuple[bytes, tuple[str, int] | None]] = (
            collections.deque()
        )
        self._buffer_size = 0
        self._protocol_paused = False

        # With GRO a single read can return a coalesced super-datagram, so
        # use the large buffer; otherwise one segment is enough.
        self._recv_buf_size = _DEFAULT_GRO_BUF if gro_enabled else gro_segment_size

        self._extra = {
            "peername": address,
            "socket": sock,
            "sockname": sock.getsockname(),
        }

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        """Standard transport extra-info lookup (peername/socket/sockname)."""
        return self._extra.get(name, default)

    def is_closing(self) -> bool:
        """True once close() or abort() has been requested."""
        return self._closing

    def close(self) -> None:
        """Stop reading and schedule teardown once the write queue drains."""
        if self._closing:
            return
        self._closing = True
        self._loop.remove_reader(self._sock.fileno())
        # Drain the write queue gracefully in the background
        if not self._send_queue:
            self._loop.call_soon(self._call_connection_lost, None)

    def abort(self) -> None:
        """Tear down immediately, discarding any queued outgoing datagrams."""
        self._closing = True
        self._call_connection_lost(None)

    def _call_connection_lost(self, exc: Exception | None) -> None:
        """Final teardown: detach loop callbacks, notify protocol, close fd."""
        try:
            self._loop.remove_reader(self._sock.fileno())
            self._loop.remove_writer(self._sock.fileno())
        except Exception:
            # The fd may already be invalid at this point; best-effort only.
            pass
        try:
            self._protocol.connection_lost(exc)
        finally:
            self._sock.close()
            if not self._closed_fut.done():
                self._closed_fut.set_result(None)

    def sendto(self, data: bytes, addr: tuple[str, int] | None = None) -> None:  # type: ignore[override]
        """Send one datagram, queueing it when the socket buffer is full."""
        if self._closing:
            raise OSError("Transport is closing")

        target = addr or self._address
        if not self._write_ready:
            # A previous send hit EWOULDBLOCK — preserve ordering by queueing.
            self._queue_write(data, target)
            return

        try:
            if target is not None:
                self._sock.sendto(data, target)
            else:
                self._sock.send(data)
        except BlockingIOError:
            self._write_ready = False
            self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
            self._queue_write(data, target)
        except OSError as exc:
            self._protocol.error_received(exc)

    def sendto_many(self, datagrams: list[bytes]) -> None:
        """Send multiple datagrams, using GSO when available.

        Falls back to individual ``sendto`` calls when GSO is not
        supported or the socket write buffer is full."""
        if self._closing:
            raise OSError("Transport is closing")

        if not self._write_ready:
            target = self._address
            for dgram in datagrams:
                self._queue_write(dgram, target)
            return

        if self._gso_enabled:
            self._send_linux_gso(datagrams)
        else:
            for dgram in datagrams:
                self.sendto(dgram)

    def _send_linux_gso(self, datagrams: list[bytes]) -> None:
        """Send *datagrams* grouped by size through GSO ``sendmsg`` calls."""
        for segment_size, group in _group_by_segment_size(datagrams):
            if len(group) == 1:
                # Single datagram — plain send (GSO needs >1 segment)
                try:
                    self._sock.send(group[0])
                except BlockingIOError:
                    self._write_ready = False
                    self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
                    self._queue_write(group[0], self._address)
                    return
                except OSError as exc:
                    self._protocol.error_received(exc)
                    continue

            buf = b"".join(group)
            try:
                # The cmsg tells the kernel how to re-split buf into segments.
                self._sock.sendmsg(
                    [buf],
                    [(socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(segment_size))],
                )
            except BlockingIOError:
                self._write_ready = False
                self._loop.add_writer(self._sock.fileno(), self._on_write_ready)
                # Queue individual datagrams as fallback
                for dgram in group:
                    self._queue_write(dgram, self._address)
                return
            except OSError as exc:
                self._protocol.error_received(exc)

    def _queue_write(self, data: bytes, addr: tuple[str, int] | None) -> None:
        """Append a pending datagram and apply write-side flow control."""
        self._send_queue.append((data, addr))
        self._buffer_size += len(data)
        self._maybe_pause_protocol()

    def _maybe_pause_protocol(self) -> None:
        """Ask the protocol to stop writing above the high watermark."""
        if self._buffer_size >= _HIGH_WATERMARK and not self._protocol_paused:
            self._protocol_paused = True
            try:
                self._protocol.pause_writing()
            except AttributeError:
                # Protocol may not implement flow-control hooks.
                pass

    def _maybe_resume_protocol(self) -> None:
        """Let the protocol resume writing below the low watermark."""
        if self._protocol_paused and self._buffer_size <= _LOW_WATERMARK:
            self._protocol_paused = False
            try:
                self._protocol.resume_writing()
            except AttributeError:
                pass

    def _on_write_ready(self) -> None:
        """Writer callback: flush the queue until empty or EWOULDBLOCK."""
        while self._send_queue:
            data, addr = self._send_queue[0]
            try:
                if addr is not None:
                    self._sock.sendto(data, addr)
                else:
                    self._sock.send(data)
            except BlockingIOError:
                # Still full: keep the writer callback armed and retry later.
                return
            except OSError as exc:
                # Report and fall through: the datagram is dropped below.
                self._protocol.error_received(exc)

            self._send_queue.popleft()
            self._buffer_size -= len(data)

        self._maybe_resume_protocol()
        self._write_ready = True
        self._loop.remove_writer(self._sock.fileno())

        if self._closing:
            # close() deferred teardown until the queue drained — finish now.
            self._call_connection_lost(None)

    def pause_reading(self) -> None:
        """Stop delivering datagrams until resume_reading() is called."""
        if not self._paused:
            self._paused = True
            self._loop.remove_reader(self._sock.fileno())

    def resume_reading(self) -> None:
        """Re-arm the reader callback after a pause_reading()."""
        if self._paused:
            self._paused = False
            self._loop.add_reader(self._sock.fileno(), self._on_readable)

    def _start(self) -> None:
        """Announce the transport to the protocol and begin reading."""
        self._loop.call_soon(self._protocol.connection_made, self)
        self._loop.add_reader(self._sock.fileno(), self._on_readable)

    def _on_readable(self) -> None:
        """Reader callback entry point."""
        if self._closing:
            return
        self._recv_linux_gro()

    def _recv_linux_gro(self) -> None:
        """Drain the socket, splitting GRO super-datagrams per the cmsg."""
        ancbufsize = socket.CMSG_SPACE(_UINT16.size)
        while True:
            try:
                data, ancdata, _flags, addr = self._sock.recvmsg(
                    self._recv_buf_size, ancbufsize
                )
            except BlockingIOError:
                # Socket fully drained for this readiness event.
                return
            except OSError as exc:
                self._protocol.error_received(exc)
                return

            if not data:
                return

            segment_size = self._gro_segment_size
            for cmsg_level, cmsg_type, cmsg_data in ancdata:
                if cmsg_level == socket.SOL_UDP and cmsg_type == UDP_LINUX_GRO:
                    # Kernel-provided size of each coalesced segment (uint16).
                    (segment_size,) = _UINT16.unpack(cmsg_data[:2])
                    break

            if len(data) <= segment_size:
                self._protocol.datagram_received(data, addr)
            else:
                segments = _split_gro_buffer(data, segment_size)
                # NOTE(review): relies on the protocol implementing the
                # non-standard ``datagrams_received`` batch hook.
                self._protocol.datagrams_received(segments, addr)  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def create_udp_endpoint(
    loop: asyncio.AbstractEventLoop,
    protocol_factory: Callable[[], asyncio.DatagramProtocol],
    *,
    local_addr: tuple[str, int] | None = None,
    remote_addr: tuple[str, int] | None = None,
    family: int = socket.AF_UNSPEC,
    reuse_port: bool = False,
    gro_segment_size: int = 1280,
    sock: socket.socket | None = None,
) -> tuple[asyncio.DatagramTransport, asyncio.DatagramProtocol]:
    """Create a UDP endpoint, preferring the GRO/GSO-optimized transport.

    When neither GRO nor GSO is usable on the socket, falls back to the
    stock ``loop.create_datagram_endpoint``. The caller is responsible
    for enabling GRO (via setsockopt) before handing in *sock*.

    :param loop: running event loop driving the transport.
    :param protocol_factory: zero-arg callable producing the protocol.
    :param sock: optional pre-made (and possibly pre-connected) socket;
        when given, address resolution / bind / connect are skipped.
    :return: the ``(transport, protocol)`` pair.
    """
    if sock is not None:
        # Caller provided a pre-connected socket — skip creation/bind/connect.
        try:
            connected_addr = sock.getpeername()
        except OSError:
            # Unconnected socket: sendto() will need explicit addresses.
            connected_addr = None
    else:
        # 1. Resolve Addresses
        if family == socket.AF_UNSPEC:
            target_addr = local_addr or remote_addr
            if target_addr:
                infos = await loop.getaddrinfo(
                    target_addr[0], target_addr[1], type=socket.SOCK_DGRAM
                )
                family = infos[0][0]
            else:
                family = socket.AF_INET

        # 2. Create Socket
        sock = socket.socket(family, socket.SOCK_DGRAM)
        sock.setblocking(False)

        if reuse_port:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if local_addr:
            sock.bind(local_addr)

        connected_addr = None

        if remote_addr:
            await loop.sock_connect(sock, remote_addr)
            connected_addr = remote_addr

    # 3. Determine capabilities — the caller is responsible for
    # enabling GRO via setsockopt before handing us the socket.
    gro_enabled = _sock_has_gro(sock)
    gso_enabled = _sock_has_gso(sock)

    if not gro_enabled and not gso_enabled:
        # No kernel offload available: the stock asyncio transport is just
        # as good, so avoid our custom write queue. (The former
        # ``lambda: protocol_factory()`` wrapper was redundant — the
        # factory is already a zero-arg callable.)
        return await loop.create_datagram_endpoint(protocol_factory, sock=sock)

    # 4. Wire up optimized transport
    protocol = protocol_factory()

    transport = _OptimizedDatagramTransport(
        loop=loop,
        sock=sock,
        protocol=protocol,
        address=connected_addr,
        gro_enabled=gro_enabled,
        gso_enabled=gso_enabled,
        gro_segment_size=gro_segment_size,
    )

    transport._start()

    return transport, protocol
|
||||
|
||||
|
||||
class DatagramReader:
    """Duck-typed stand-in for ``asyncio.StreamReader`` so that
    ``AsyncSocket`` can keep using its existing ``recv()`` path through
    ``self._reader`` without modification.

    When GRO coalesces several segments into one syscall,
    ``feed_datagrams()`` queues the whole batch as a single
    ``list[bytes]`` entry. ``read()`` then hands that list back verbatim
    so the caller can push every segment into the QUIC state-machine in
    one pass before probing — avoiding the per-datagram
    recv→feed→probe round-trip overhead."""

    def __init__(self) -> None:
        # Queue entries are either one datagram (bytes) or a GRO batch
        # (list[bytes]); read() pops them FIFO.
        self._buffer: deque[bytes | list[bytes]] = collections.deque()
        # At most one pending read() waits on this future at a time.
        self._waiter: asyncio.Future[None] | None = None
        self._exception: BaseException | None = None
        self._eof = False

    def feed_datagram(self, data: bytes, addr: Any) -> None:
        """Queue one plain (non-coalesced) datagram."""
        self._buffer.append(data)
        self._wake_waiter()

    def feed_datagrams(self, data: list[bytes], addr: Any) -> None:
        """Queue a batch of coalesced datagrams as a single entry."""
        self._buffer.append(data)
        self._wake_waiter()

    def set_exception(self, exc: BaseException) -> None:
        """Record *exc*; the next read() (re-)raises it."""
        self._exception = exc
        self._wake_waiter()

    def connection_lost(self, exc: BaseException | None) -> None:
        """Mark EOF; keep *exc* (if any) so read() raises instead."""
        self._eof = True
        if exc is not None:
            self._exception = exc
        self._wake_waiter()

    def _wake_waiter(self) -> None:
        # Resolve the pending read(), if one is parked on the future.
        pending = self._waiter
        if pending is None or pending.done():
            return
        pending.set_result(None)

    def _pop_ready(self) -> bytes | list[bytes] | None:
        """Pop the next buffered entry, raise a stored error, or
        return None when nothing is available yet."""
        if self._buffer:
            return self._buffer.popleft()
        if self._exception is not None:
            raise self._exception
        return None

    async def read(self, n: int = -1) -> bytes | list[bytes]:
        """Return the next entry from the buffer.

        * ``bytes`` — a single datagram (non-coalesced).
        * ``list[bytes]`` — a batch of coalesced datagrams from one
          GRO syscall.
        * ``b""`` — EOF.

        ``n`` is accepted only for StreamReader API compatibility.
        """
        entry = self._pop_ready()
        if entry is not None:
            return entry

        if not self._eof:
            # Nothing buffered and not EOF: park until the protocol
            # feeds data, reports an error, or loses the connection.
            self._waiter = asyncio.get_running_loop().create_future()
            try:
                await self._waiter
            finally:
                self._waiter = None

            entry = self._pop_ready()
            if entry is not None:
                return entry

        return b""
|
||||
|
||||
|
||||
class DatagramWriter:
    """Duck-typed stand-in for ``asyncio.StreamWriter`` so that
    ``AsyncSocket`` can keep its existing ``sendall()``, ``close()``
    and ``wait_for_close()`` code paths unchanged."""

    def __init__(
        self,
        transport: asyncio.DatagramTransport,
    ) -> None:
        self._transport = transport
        # Connected peer address if the transport was connected;
        # None means sendto() targets the transport's default.
        self._address: tuple[str, int] | None = transport.get_extra_info("peername")
        # Set by the bridge protocol when connection_lost fires.
        self._closed_event = asyncio.Event()
        # Flow-control state driven by pause_writing()/resume_writing().
        self._paused = False
        self._drain_waiter: asyncio.Future[None] | None = None

    @property
    def transport(self) -> asyncio.DatagramTransport:
        return self._transport

    def write(self, data: bytes | bytearray | memoryview | list[bytes]) -> None:
        """Send a single datagram or a batch of datagrams.

        Silently drops writes once the transport is closing."""
        transport = self._transport
        if transport.is_closing():
            return

        if not isinstance(data, list):
            transport.sendto(bytes(data), self._address)
            return

        batched_send = getattr(transport, "sendto_many", None)
        if batched_send is not None:
            batched_send(data)
        else:
            # Plain asyncio transport — send individually
            for item in data:
                transport.sendto(item, self._address)

    async def drain(self) -> None:
        """Block until resume_writing() if flow control paused us."""
        if not self._paused:
            return
        self._drain_waiter = asyncio.get_running_loop().create_future()
        try:
            await self._drain_waiter
        finally:
            self._drain_waiter = None

    def close(self) -> None:
        self._transport.close()

    async def wait_closed(self) -> None:
        """Wait for the bridge protocol to report connection loss."""
        await self._closed_event.wait()

    def get_extra_info(self, name: str, default: Any = None) -> Any:
        return self._transport.get_extra_info(name, default)

    def _pause_writing(self) -> None:
        self._paused = True

    def _resume_writing(self) -> None:
        self._paused = False
        pending = self._drain_waiter
        if pending is not None and not pending.done():
            pending.set_result(None)
|
||||
|
||||
|
||||
class _DatagramBridgeProtocol(asyncio.DatagramProtocol):
|
||||
"""Bridges ``asyncio.DatagramProtocol`` callbacks to
|
||||
``DatagramReader`` / ``DatagramWriter``."""
|
||||
|
||||
def __init__(self, reader: DatagramReader) -> None:
|
||||
self._reader = reader
|
||||
self._writer: DatagramWriter | None = None
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
pass # transport is already wired via DatagramWriter
|
||||
|
||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
self._reader.feed_datagram(data, addr)
|
||||
|
||||
def datagrams_received(self, data: list[bytes], addr: tuple[str, int]) -> None:
|
||||
self._reader.feed_datagrams(data, addr)
|
||||
|
||||
def error_received(self, exc: Exception) -> None:
|
||||
self._reader.set_exception(exc)
|
||||
|
||||
def connection_lost(self, exc: BaseException | None) -> None:
|
||||
self._reader.connection_lost(exc)
|
||||
if self._writer is not None:
|
||||
self._writer._closed_event.set()
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
if self._writer is not None:
|
||||
self._writer._pause_writing()
|
||||
|
||||
def resume_writing(self) -> None:
|
||||
if self._writer is not None:
|
||||
self._writer._resume_writing()
|
||||
|
||||
|
||||
async def open_dgram_connection(
    remote_addr: tuple[str, int] | None = None,
    *,
    local_addr: tuple[str, int] | None = None,
    family: int = socket.AF_UNSPEC,
    sock: socket.socket | None = None,
    gro_segment_size: int = 1280,
) -> tuple[DatagramReader, DatagramWriter]:
    """Open a UDP endpoint and return a (reader, writer) pair.

    Datagram analogue of ``asyncio.open_connection``: builds the
    transport via ``create_udp_endpoint`` and bridges its protocol
    callbacks into a ``DatagramReader`` / ``DatagramWriter``.
    """
    event_loop = asyncio.get_running_loop()

    reader = DatagramReader()
    bridge = _DatagramBridgeProtocol(reader)

    transport, _ = await create_udp_endpoint(
        event_loop,
        lambda: bridge,
        local_addr=local_addr,
        remote_addr=remote_addr,
        family=family,
        gro_segment_size=gro_segment_size,
        sock=sock,
    )

    writer = DatagramWriter(transport)
    # Late-bind the writer so flow-control/close callbacks can reach it.
    bridge._writer = writer

    return reader, writer
|
||||
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from asyncio import CancelledError, events, tasks
|
||||
from types import TracebackType
|
||||
|
||||
__all__ = (
|
||||
"Timeout",
|
||||
"timeout",
|
||||
)
|
||||
|
||||
|
||||
class _State(enum.Enum):
    # Lifecycle of a Timeout context manager. Note the member values are
    # display strings (used in __repr__ and error messages), which is why
    # ENTERED/EXITED render as "active"/"finished".
    CREATED = "created"  # constructed, __aenter__ not yet called
    ENTERED = "active"  # inside the async-with block, deadline armed
    EXPIRING = "expiring"  # deadline hit, task.cancel() issued
    EXPIRED = "expired"  # cancellation converted to TimeoutError on exit
    EXITED = "finished"  # exited the block without expiring
|
||||
|
||||
|
||||
class Timeout:
    """Asynchronous context manager for cancelling overdue coroutines.

    Use `timeout()` or `timeout_at()` rather than instantiating this class directly.
    """

    def __init__(self, when: float | None) -> None:
        """Schedule a timeout that will trigger at a given loop time.

        - If `when` is `None`, the timeout will never trigger.
        - If `when < loop.time()`, the timeout will trigger on the next
          iteration of the event loop.
        """
        self._state = _State.CREATED

        # Pending loop callback; cancelled/replaced by reschedule() and
        # cleared once it fires or on __aexit__.
        self._timeout_handler: events.TimerHandle | events.Handle | None = None
        # Task to cancel when the deadline passes (bound in __aenter__).
        self._task: tasks.Task | None = None  # type: ignore[type-arg]
        self._when = when

    def when(self) -> float | None:
        """Return the current deadline."""
        return self._when

    def reschedule(self, when: float | None) -> None:
        """Reschedule the timeout.

        Only legal while ENTERED; raises RuntimeError from any other
        state so callers cannot re-arm a finished/expired timeout.
        """
        if self._state is not _State.ENTERED:
            if self._state is _State.CREATED:
                raise RuntimeError("Timeout has not been entered")
            raise RuntimeError(
                f"Cannot change state of {self._state.value} Timeout",
            )

        self._when = when

        # Drop any previously armed callback before arming the new one.
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()

        if when is None:
            self._timeout_handler = None
        else:
            loop = events.get_running_loop()
            if when <= loop.time():
                # Deadline already passed: fire on the next loop iteration.
                self._timeout_handler = loop.call_soon(self._on_timeout)
            else:
                self._timeout_handler = loop.call_at(when, self._on_timeout)

    def expired(self) -> bool:
        """Is timeout expired during execution?"""
        return self._state in (_State.EXPIRING, _State.EXPIRED)

    def __repr__(self) -> str:
        # e.g. "<Timeout [active] when=12.345>"; the deadline is shown
        # only while the context manager is entered.
        info = [""]
        if self._state is _State.ENTERED:
            when = round(self._when, 3) if self._when is not None else None
            info.append(f"when={when}")
        info_str = " ".join(info)
        return f"<Timeout [{self._state.value}]{info_str}>"

    async def __aenter__(self) -> "Timeout":
        if self._state is not _State.CREATED:
            raise RuntimeError("Timeout has already been entered")
        # A timeout cancels the *current task*, so one must exist.
        task = tasks.current_task()
        if task is None:
            raise RuntimeError("Timeout should be used inside a task")
        self._state = _State.ENTERED
        self._task = task
        self.reschedule(self._when)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        assert self._state in (_State.ENTERED, _State.EXPIRING)
        assert self._task is not None

        # Disarm the deadline callback if it has not fired yet.
        if self._timeout_handler is not None:
            self._timeout_handler.cancel()
            self._timeout_handler = None

        if self._state is _State.EXPIRING:
            self._state = _State.EXPIRED

            # Our _on_timeout() cancelled the task; translate the
            # resulting CancelledError into TimeoutError for callers.
            # NOTE(review): unlike CPython's asyncio.Timeout this does
            # not call task.uncancel(), presumably to stay compatible
            # with Python < 3.11 — confirm before changing.
            if exc_type is CancelledError:
                # Since there are no new cancel requests, we're
                # handling this.
                raise TimeoutError from exc_val
        elif self._state is _State.ENTERED:
            self._state = _State.EXITED

        return None

    def _on_timeout(self) -> None:
        # Deadline callback: cancel the guarded task and flag that the
        # upcoming CancelledError is ours to convert in __aexit__.
        assert self._state is _State.ENTERED
        assert self._task is not None

        self._task.cancel()
        self._state = _State.EXPIRING
        # drop the reference early
        self._timeout_handler = None
|
||||
|
||||
|
||||
def timeout(delay: float | None) -> Timeout:
    """Timeout async context manager.

    Useful in cases when you want to apply timeout logic around block
    of code or in cases when asyncio.wait_for is not suitable. For example:

    >>> async with asyncio.timeout(10):  # 10 seconds timeout
    ...     await long_running_task()

    delay - value in seconds or None to disable timeout logic

    long_running_task() is interrupted by raising asyncio.CancelledError,
    the top-most affected timeout() context manager converts CancelledError
    into TimeoutError.
    """
    # Requires a running loop even when delay is None, matching
    # asyncio.timeout's behaviour.
    loop = events.get_running_loop()
    if delay is None:
        deadline = None
    else:
        deadline = loop.time() + delay
    return Timeout(deadline)
|
||||
Reference in New Issue
Block a user