""" High-performance asyncio DatagramTransport with Linux-specific UDP receive/send coalescing: - GRO (receive): ``setsockopt(SOL_UDP, UDP_GRO)`` + ``recvmsg`` cmsg - GSO (send): ``sendmsg`` with ``UDP_SEGMENT`` cmsg All other platforms fall back to the standard asyncio DatagramTransport. """ from __future__ import annotations import asyncio import collections import socket import struct from collections import deque from typing import Any, Callable from ..._constant import UDP_LINUX_GRO, UDP_LINUX_SEGMENT _UINT16 = struct.Struct("=H") _DEFAULT_GRO_BUF = 65535 # Flow control watermarks for the custom write queue _HIGH_WATERMARK = 64 * 1024 _LOW_WATERMARK = 16 * 1024 # GSO kernel limit: max segments per sendmsg call _GSO_MAX_SEGMENTS = 64 def _sock_has_gro(sock: socket.socket) -> bool: """Check if GRO is enabled on *sock* (caller must have set it).""" try: return sock.getsockopt(socket.SOL_UDP, UDP_LINUX_GRO) == 1 except OSError: return False def _sock_has_gso(sock: socket.socket) -> bool: """Check if the kernel supports GSO on *sock*.""" try: sock.getsockopt(socket.SOL_UDP, UDP_LINUX_SEGMENT) return True except OSError: return False def _split_gro_buffer(buf: bytes, segment_size: int) -> list[bytes]: if segment_size <= 0 or len(buf) <= segment_size: return [buf] segments = [] mv = memoryview(buf) for offset in range(0, len(buf), segment_size): segments.append(bytes(mv[offset : offset + segment_size])) return segments def _group_by_segment_size(datagrams: list[bytes]) -> list[tuple[int, list[bytes]]]: """Group consecutive same-size datagrams for Linux UDP GSO. GSO requires all segments to be the same size (except the last, which may be shorter). Max 64 segments per ``sendmsg`` call.""" if not datagrams: return [] groups: list[tuple[int, list[bytes]]] = [] current_size = len(datagrams[0]) current_group: list[bytes] = [datagrams[0]] for dgram in datagrams[1:]: if len(dgram) == current_size and len(current_group) < _GSO_MAX_SEGMENTS: current_group.append(dgram) else: groups.append((current_size, current_group)) current_size = len(dgram) current_group = [dgram] groups.append((current_size, current_group)) return groups def sync_recv_gro( sock: socket.socket, bufsize: int, gro_segment_size: int = 1280 ) -> bytes | list[bytes]: """Blocking recvmsg with GRO cmsg parsing. Returns bytes or list[bytes].""" ancbufsize = socket.CMSG_SPACE(_UINT16.size) data, ancdata, _flags, addr = sock.recvmsg(bufsize, ancbufsize) if not data: return b"" segment_size = gro_segment_size for cmsg_level, cmsg_type, cmsg_data in ancdata: if cmsg_level == socket.SOL_UDP and cmsg_type == UDP_LINUX_GRO: (segment_size,) = _UINT16.unpack(cmsg_data[:2]) break if len(data) <= segment_size: return data return _split_gro_buffer(data, segment_size) def sync_sendmsg_gso(sock: socket.socket, datagrams: list[bytes]) -> None: """Batch-send datagrams using GSO. Falls back to individual sends.""" for segment_size, group in _group_by_segment_size(datagrams): if len(group) == 1: sock.sendall(group[0]) continue buf = b"".join(group) sock.sendmsg( [buf], [(socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(segment_size))], ) class _OptimizedDatagramTransport(asyncio.DatagramTransport): __slots__ = ( "_loop", "_sock", "_protocol", "_address", "_gro_enabled", "_gso_enabled", "_gro_segment_size", "_recv_buf_size", "_closing", "_closed_fut", "_extra", "_paused", "_write_ready", "_send_queue", "_buffer_size", "_protocol_paused", ) def __init__( self, loop: asyncio.AbstractEventLoop, sock: socket.socket, protocol: asyncio.DatagramProtocol, address: tuple[str, int] | None, gro_enabled: bool, gso_enabled: bool, gro_segment_size: int, ) -> None: super().__init__() self._loop = loop self._sock = sock self._protocol = protocol self._address = address self._gro_enabled = gro_enabled self._gso_enabled = gso_enabled self._gro_segment_size = gro_segment_size self._closing = False self._closed_fut: asyncio.Future[None] = loop.create_future() self._paused = False self._write_ready = True # Write buffer state self._send_queue: deque[tuple[bytes, tuple[str, int] | None]] = ( collections.deque() ) self._buffer_size = 0 self._protocol_paused = False self._recv_buf_size = _DEFAULT_GRO_BUF if gro_enabled else gro_segment_size self._extra = { "peername": address, "socket": sock, "sockname": sock.getsockname(), } def get_extra_info(self, name: str, default: Any = None) -> Any: return self._extra.get(name, default) def is_closing(self) -> bool: return self._closing def close(self) -> None: if self._closing: return self._closing = True self._loop.remove_reader(self._sock.fileno()) # Drain the write queue gracefully in the background if not self._send_queue: self._loop.call_soon(self._call_connection_lost, None) def abort(self) -> None: self._closing = True self._call_connection_lost(None) def _call_connection_lost(self, exc: Exception | None) -> None: try: self._loop.remove_reader(self._sock.fileno()) self._loop.remove_writer(self._sock.fileno()) except Exception: pass try: self._protocol.connection_lost(exc) finally: self._sock.close() if not self._closed_fut.done(): self._closed_fut.set_result(None) def sendto(self, data: bytes, addr: tuple[str, int] | None = None) -> None: # type: ignore[override] if self._closing: raise OSError("Transport is closing") target = addr or self._address if not self._write_ready: self._queue_write(data, target) return try: if target is not None: self._sock.sendto(data, target) else: self._sock.send(data) except BlockingIOError: self._write_ready = False self._loop.add_writer(self._sock.fileno(), self._on_write_ready) self._queue_write(data, target) except OSError as exc: self._protocol.error_received(exc) def sendto_many(self, datagrams: list[bytes]) -> None: """Send multiple datagrams, using GSO when available. Falls back to individual ``sendto`` calls when GSO is not supported or the socket write buffer is full.""" if self._closing: raise OSError("Transport is closing") if not self._write_ready: target = self._address for dgram in datagrams: self._queue_write(dgram, target) return if self._gso_enabled: self._send_linux_gso(datagrams) else: for dgram in datagrams: self.sendto(dgram) def _send_linux_gso(self, datagrams: list[bytes]) -> None: for segment_size, group in _group_by_segment_size(datagrams): if len(group) == 1: # Single datagram — plain send (GSO needs >1 segment) try: self._sock.send(group[0]) except BlockingIOError: self._write_ready = False self._loop.add_writer(self._sock.fileno(), self._on_write_ready) self._queue_write(group[0], self._address) return except OSError as exc: self._protocol.error_received(exc) continue buf = b"".join(group) try: self._sock.sendmsg( [buf], [(socket.SOL_UDP, UDP_LINUX_SEGMENT, _UINT16.pack(segment_size))], ) except BlockingIOError: self._write_ready = False self._loop.add_writer(self._sock.fileno(), self._on_write_ready) # Queue individual datagrams as fallback for dgram in group: self._queue_write(dgram, self._address) return except OSError as exc: self._protocol.error_received(exc) def _queue_write(self, data: bytes, addr: tuple[str, int] | None) -> None: self._send_queue.append((data, addr)) self._buffer_size += len(data) self._maybe_pause_protocol() def _maybe_pause_protocol(self) -> None: if self._buffer_size >= _HIGH_WATERMARK and not self._protocol_paused: self._protocol_paused = True try: self._protocol.pause_writing() except AttributeError: pass def _maybe_resume_protocol(self) -> None: if self._protocol_paused and self._buffer_size <= _LOW_WATERMARK: self._protocol_paused = False try: self._protocol.resume_writing() except AttributeError: pass def _on_write_ready(self) -> None: while self._send_queue: data, addr = self._send_queue[0] try: if addr is not None: self._sock.sendto(data, addr) else: self._sock.send(data) except BlockingIOError: return except OSError as exc: self._protocol.error_received(exc) self._send_queue.popleft() self._buffer_size -= len(data) self._maybe_resume_protocol() self._write_ready = True self._loop.remove_writer(self._sock.fileno()) if self._closing: self._call_connection_lost(None) def pause_reading(self) -> None: if not self._paused: self._paused = True self._loop.remove_reader(self._sock.fileno()) def resume_reading(self) -> None: if self._paused: self._paused = False self._loop.add_reader(self._sock.fileno(), self._on_readable) def _start(self) -> None: self._loop.call_soon(self._protocol.connection_made, self) self._loop.add_reader(self._sock.fileno(), self._on_readable) def _on_readable(self) -> None: if self._closing: return self._recv_linux_gro() def _recv_linux_gro(self) -> None: ancbufsize = socket.CMSG_SPACE(_UINT16.size) while True: try: data, ancdata, _flags, addr = self._sock.recvmsg( self._recv_buf_size, ancbufsize ) except BlockingIOError: return except OSError as exc: self._protocol.error_received(exc) return if not data: return segment_size = self._gro_segment_size for cmsg_level, cmsg_type, cmsg_data in ancdata: if cmsg_level == socket.SOL_UDP and cmsg_type == UDP_LINUX_GRO: (segment_size,) = _UINT16.unpack(cmsg_data[:2]) break if len(data) <= segment_size: self._protocol.datagram_received(data, addr) else: segments = _split_gro_buffer(data, segment_size) self._protocol.datagrams_received(segments, addr) # type: ignore[attr-defined] async def create_udp_endpoint( loop: asyncio.AbstractEventLoop, protocol_factory: Callable[[], asyncio.DatagramProtocol], *, local_addr: tuple[str, int] | None = None, remote_addr: tuple[str, int] | None = None, family: int = socket.AF_UNSPEC, reuse_port: bool = False, gro_segment_size: int = 1280, sock: socket.socket | None = None, ) -> tuple[asyncio.DatagramTransport, asyncio.DatagramProtocol]: if sock is not None: # Caller provided a pre-connected socket — skip creation/bind/connect. try: connected_addr = sock.getpeername() except OSError: connected_addr = None else: # 1. Resolve Addresses if family == socket.AF_UNSPEC: target_addr = local_addr or remote_addr if target_addr: infos = await loop.getaddrinfo( target_addr[0], target_addr[1], type=socket.SOCK_DGRAM ) family = infos[0][0] else: family = socket.AF_INET # 2. Create Socket sock = socket.socket(family, socket.SOCK_DGRAM) sock.setblocking(False) if reuse_port: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) if local_addr: sock.bind(local_addr) connected_addr = None if remote_addr: await loop.sock_connect(sock, remote_addr) connected_addr = remote_addr # 3. Determine capabilities — the caller is responsible for # enabling GRO via setsockopt before handing us the socket. gro_enabled = _sock_has_gro(sock) gso_enabled = _sock_has_gso(sock) if not gro_enabled and not gso_enabled: return await loop.create_datagram_endpoint( lambda: protocol_factory(), sock=sock ) # 4. Wire up optimized transport protocol = protocol_factory() transport = _OptimizedDatagramTransport( loop=loop, sock=sock, protocol=protocol, address=connected_addr, gro_enabled=gro_enabled, gso_enabled=gso_enabled, gro_segment_size=gro_segment_size, ) transport._start() return transport, protocol class DatagramReader: """API-compatible with ``asyncio.StreamReader`` (duck-typed) so that ``AsyncSocket`` can assign an instance to ``self._reader`` and the existing ``recv()`` code works unchanged. When GRO delivers multiple coalesced segments in a single syscall, ``feed_datagrams()`` stores them as a single ``list[bytes]`` entry. ``read()`` then returns that list directly so the caller can feed all segments to the QUIC state-machine in one pass before probing — avoiding the per-datagram recv→feed→probe round-trip overhead.""" def __init__(self) -> None: self._buffer: deque[bytes | list[bytes]] = collections.deque() self._waiter: asyncio.Future[None] | None = None self._exception: BaseException | None = None self._eof = False def feed_datagram(self, data: bytes, addr: Any) -> None: """Feed a single (non-coalesced) datagram.""" self._buffer.append(data) self._wake_waiter() def feed_datagrams(self, data: list[bytes], addr: Any) -> None: """Feed a batch of coalesced datagrams as a single entry.""" self._buffer.append(data) self._wake_waiter() def set_exception(self, exc: BaseException) -> None: self._exception = exc self._wake_waiter() def connection_lost(self, exc: BaseException | None) -> None: self._eof = True if exc is not None: self._exception = exc self._wake_waiter() def _wake_waiter(self) -> None: waiter = self._waiter if waiter is not None and not waiter.done(): waiter.set_result(None) async def read(self, n: int = -1) -> bytes | list[bytes]: """Return the next entry from the buffer. * ``bytes`` — a single datagram (non-coalesced). * ``list[bytes]`` — a batch of coalesced datagrams from one GRO syscall. * ``b""`` — EOF. """ if self._buffer: return self._buffer.popleft() if self._exception is not None: raise self._exception if self._eof: return b"" self._waiter = asyncio.get_running_loop().create_future() try: await self._waiter finally: self._waiter = None if self._buffer: return self._buffer.popleft() if self._exception is not None: raise self._exception return b"" class DatagramWriter: """API-compatible with ``asyncio.StreamWriter`` (duck-typed) so that ``AsyncSocket`` can assign an instance to ``self._writer`` and the existing ``sendall()``, ``close()``, ``wait_for_close()`` code works unchanged.""" def __init__( self, transport: asyncio.DatagramTransport, ) -> None: self._transport = transport self._address: tuple[str, int] | None = transport.get_extra_info("peername") self._closed_event = asyncio.Event() self._paused = False self._drain_waiter: asyncio.Future[None] | None = None @property def transport(self) -> asyncio.DatagramTransport: return self._transport def write(self, data: bytes | bytearray | memoryview | list[bytes]) -> None: if self._transport.is_closing(): return if isinstance(data, list): if hasattr(self._transport, "sendto_many"): self._transport.sendto_many(data) else: # Plain asyncio transport — send individually for dgram in data: self._transport.sendto(dgram, self._address) else: self._transport.sendto(bytes(data), self._address) async def drain(self) -> None: if not self._paused: return self._drain_waiter = asyncio.get_running_loop().create_future() try: await self._drain_waiter finally: self._drain_waiter = None def close(self) -> None: self._transport.close() async def wait_closed(self) -> None: await self._closed_event.wait() def get_extra_info(self, name: str, default: Any = None) -> Any: return self._transport.get_extra_info(name, default) def _pause_writing(self) -> None: self._paused = True def _resume_writing(self) -> None: self._paused = False waiter = self._drain_waiter if waiter is not None and not waiter.done(): waiter.set_result(None) class _DatagramBridgeProtocol(asyncio.DatagramProtocol): """Bridges ``asyncio.DatagramProtocol`` callbacks to ``DatagramReader`` / ``DatagramWriter``.""" def __init__(self, reader: DatagramReader) -> None: self._reader = reader self._writer: DatagramWriter | None = None def connection_made(self, transport: asyncio.BaseTransport) -> None: pass # transport is already wired via DatagramWriter def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: self._reader.feed_datagram(data, addr) def datagrams_received(self, data: list[bytes], addr: tuple[str, int]) -> None: self._reader.feed_datagrams(data, addr) def error_received(self, exc: Exception) -> None: self._reader.set_exception(exc) def connection_lost(self, exc: BaseException | None) -> None: self._reader.connection_lost(exc) if self._writer is not None: self._writer._closed_event.set() def pause_writing(self) -> None: if self._writer is not None: self._writer._pause_writing() def resume_writing(self) -> None: if self._writer is not None: self._writer._resume_writing() async def open_dgram_connection( remote_addr: tuple[str, int] | None = None, *, local_addr: tuple[str, int] | None = None, family: int = socket.AF_UNSPEC, sock: socket.socket | None = None, gro_segment_size: int = 1280, ) -> tuple[DatagramReader, DatagramWriter]: loop = asyncio.get_running_loop() reader = DatagramReader() protocol = _DatagramBridgeProtocol(reader) transport, _ = await create_udp_endpoint( loop, lambda: protocol, local_addr=local_addr, remote_addr=remote_addr, family=family, gro_segment_size=gro_segment_size, sock=sock, ) writer = DatagramWriter(transport) protocol._writer = writer return reader, writer