fix: 포트 충돌 회피 — note_bridge 8098, intent_service 8099

Jellyfin(8096), OrbStack(8097) 포트 충돌으로 변경.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Hyungi Ahn
2026-03-19 13:53:55 +09:00
parent dc08d29509
commit c2257d3a86
2709 changed files with 619549 additions and 10 deletions

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import logging
from .asyncio import QuicConnectionProtocol, connect, serve
from .h3 import events as h3_events
from .h3.connection import H3Connection, ProtocolError
from .h3.exceptions import H3Error, NoAvailablePushIDError
from .quic import events as quic_events
from .quic.configuration import QuicConfiguration
from .quic.connection import QuicConnection, QuicConnectionError
from .quic.logger import QuicFileLogger, QuicLogger
from .quic.packet import QuicProtocolVersion
from .tls import CipherSuite, SessionTicket
__version__ = "1.6.0"
__all__ = (
"connect",
"QuicConnectionProtocol",
"serve",
"h3_events",
"H3Error",
"H3Connection",
"NoAvailablePushIDError",
"quic_events",
"QuicConfiguration",
"QuicConnection",
"QuicConnectionError",
"QuicProtocolVersion",
"QuicFileLogger",
"QuicLogger",
"ProtocolError",
"CipherSuite",
"SessionTicket",
"__version__",
)
# Attach a NullHandler to the top level logger by default
# https://docs.python.org/3.3/howto/logging.html#configuring-logging-for-a-library
logging.getLogger("quic").addHandler(logging.NullHandler())
logging.getLogger("http3").addHandler(logging.NullHandler())

View File

@@ -0,0 +1,7 @@
from __future__ import annotations
import sys
DATACLASS_KWARGS = {"slots": True} if sys.version_info >= (3, 10) else {}
UINT_VAR_MAX = 0x3FFFFFFFFFFFFFFF
UINT_VAR_MAX_SIZE = 8

Binary file not shown.

View File

@@ -0,0 +1,361 @@
"""
Everything within that module is off the semver guarantees.
You use it, you deal with unexpected breakage. Anytime, anywhere.
You'd be better off using cryptography directly.
This module serve exclusively qh3 interests. You have been warned.
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Sequence
class DecompressionFailed(Exception): ...
class DecoderStreamError(Exception): ...
class EncoderStreamError(Exception): ...
class StreamBlocked(Exception): ...
class QpackDecoder:
def __init__(self, max_table_capacity: int, blocked_streams: int) -> None: ...
def feed_encoder(self, data: bytes) -> None: ...
def feed_header(
self, stream_id: int, data: bytes
) -> tuple[bytes, list[tuple[bytes, bytes]]]: ...
def resume_header(
self, stream_id: int
) -> tuple[bytes, list[tuple[bytes, bytes]]]: ...
class QpackEncoder:
def apply_settings(
self, max_table_capacity: int, dyn_table_capacity: int, blocked_streams: int
) -> bytes: ...
def encode(
self, stream_id: int, headers: list[tuple[bytes, bytes]]
) -> tuple[bytes, bytes]: ...
def feed_decoder(self, data: bytes) -> None: ...
class AeadChaCha20Poly1305:
def __init__(self, key: bytes, iv: bytes) -> None: ...
def encrypt(
self, packet_number: int, data: bytes, associated_data: bytes
) -> bytes: ...
def decrypt(
self, packet_number: int, data: bytes, associated_data: bytes
) -> bytes: ...
class AeadAes256Gcm:
def __init__(self, key: bytes, iv: bytes) -> None: ...
def encrypt(
self, packet_number: int, data: bytes, associated_data: bytes
) -> bytes: ...
def decrypt(
self, packet_number: int, data: bytes, associated_data: bytes
) -> bytes: ...
class AeadAes128Gcm:
def __init__(self, key: bytes, iv: bytes) -> None: ...
def encrypt(
self, packet_number: int, data: bytes, associated_data: bytes
) -> bytes: ...
def encrypt_with_nonce(
self, nonce: bytes, data: bytes, associated_data: bytes
) -> bytes: ...
def decrypt(
self, packet_number: int, data: bytes, associated_data: bytes
) -> bytes: ...
class ServerVerifier:
def __init__(self, authorities: list[bytes]) -> None: ...
def verify(
self,
peer: bytes,
intermediaries: list[bytes],
server_name: str,
ocsp_response: bytes,
) -> None: ...
class TlsCertUsage(Enum):
ServerAuth = 0
ClientAuth = 1
Both = 2
Other = 3
class Certificate:
"""
A (very) straightforward class to expose a parsed X509 certificate.
This is hazardous material, nothing in there is guaranteed to
remain backward compatible.
Use with care...
"""
def __init__(self, certificate_der: bytes) -> None: ...
@property
def subject(self):
list[tuple[str, str, bytes]]
@property
def issuer(self):
list[tuple[str, str, bytes]]
@property
def not_valid_after(self) -> int: ...
@property
def not_valid_before(self) -> int: ...
@property
def serial_number(self) -> str: ...
def get_extension_for_oid(self, oid: str) -> list[tuple[str, bool, bytes]]: ...
@property
def version(self) -> int: ...
def get_ocsp_endpoints(self) -> list[bytes]: ...
def get_crl_endpoints(self) -> list[bytes]: ...
def get_issuer_endpoints(self) -> list[bytes]: ...
def get_subject_alt_names(self) -> list[bytes]: ...
def public_bytes(self) -> bytes: ...
def public_key(self) -> bytes: ...
@property
def self_signed(self) -> bool: ...
@property
def is_ca(self) -> bool: ...
@property
def usage(self) -> TlsCertUsage: ...
def serialize(self) -> bytes: ...
@staticmethod
def deserialize(src: bytes) -> Certificate: ...
class Rsa:
"""
This binding host a RSA Private/Public Keys.
Use Oaep (padding) + SHA256 under. Not customizable.
"""
def __init__(self, key_size: int) -> None: ...
def encrypt(self, data: bytes) -> bytes: ...
def decrypt(self, data: bytes) -> bytes: ...
class EcPrivateKey:
def __init__(self, der_key: bytes, curve_type: int, is_pkcs8: bool) -> None: ...
def public_key(self) -> bytes: ...
def sign(self, data: bytes) -> bytes: ...
@property
def curve_type(self) -> int: ...
class Ed25519PrivateKey:
def __init__(self, pkcs8: bytes) -> None: ...
def public_key(self) -> bytes: ...
def sign(self, data: bytes) -> bytes: ...
class DsaPrivateKey:
def __init__(self, pkcs8: bytes) -> None: ...
def public_key(self) -> bytes: ...
def sign(self, data: bytes) -> bytes: ...
class RsaPrivateKey:
def __init__(self, pkcs8: bytes) -> None: ...
def public_key(self) -> bytes: ...
def sign(self, data: bytes, padding, hash_size: int) -> bytes: ...
def verify_with_public_key(
public_key_raw: bytes, algorithm: int, message: bytes, signature: bytes
) -> None: ...
class X25519ML768KeyExchange:
def __init__(self) -> None: ...
def public_key(self) -> bytes: ...
def exchange(self, peer_public_key: bytes) -> bytes: ...
def shared_ciphertext(self) -> bytes: ...
class X25519KeyExchange:
def __init__(self) -> None: ...
def public_key(self) -> bytes: ...
def exchange(self, peer_public_key: bytes) -> bytes: ...
class ECDHP256KeyExchange:
def __init__(self) -> None: ...
def public_key(self) -> bytes: ...
def exchange(self, peer_public_key: bytes) -> bytes: ...
class ECDHP384KeyExchange:
def __init__(self) -> None: ...
def public_key(self) -> bytes: ...
def exchange(self, peer_public_key: bytes) -> bytes: ...
class ECDHP521KeyExchange:
def __init__(self) -> None: ...
def public_key(self) -> bytes: ...
def exchange(self, peer_public_key: bytes) -> bytes: ...
class CryptoError(Exception): ...
class KeyType(Enum):
ECDSA_P256 = 0
ECDSA_P384 = 1
ECDSA_P521 = 2
ED25519 = 3
DSA = 4
RSA = 5
class PrivateKeyInfo:
"""
Load a PEM private key and extract valuable info from it.
Does two things, provide a DER encoded key and hint
toward its nature (eg. EC, RSA, DSA, etc...)
"""
def __init__(self, raw_pem_content: bytes, password: bytes | None) -> None: ...
def public_bytes(self) -> bytes: ...
def get_type(self) -> KeyType: ...
class SelfSignedCertificateError(Exception): ...
class InvalidNameCertificateError(Exception): ...
class ExpiredCertificateError(Exception): ...
class UnacceptableCertificateError(Exception): ...
class SignatureError(Exception): ...
class QUICHeaderProtection:
def __init__(self, algorithm: str, key: bytes) -> None: ...
def apply(self, plain_header: bytes, protected_payload: bytes) -> bytes: ...
def remove(self, packet: bytes, pn_offset: int) -> tuple[bytes, int]: ...
def mask(self, sample: bytes) -> bytes: ...
class ReasonFlags(Enum):
unspecified = 0
key_compromise = 1
ca_compromise = 2
affiliation_changed = 3
superseded = 4
cessation_of_operation = 5
certificate_hold = 6
privilege_withdrawn = 9
aa_compromise = 10
remove_from_crl = 8
class OCSPResponseStatus(Enum):
SUCCESSFUL = 0
MALFORMED_REQUEST = 1
INTERNAL_ERROR = 2
TRY_LATER = 3
SIG_REQUIRED = 5
UNAUTHORIZED = 6
class OCSPCertStatus(Enum):
GOOD = 0
REVOKED = 1
UNKNOWN = 2
class OCSPResponse:
def __init__(self, raw_response: bytes) -> None: ...
@property
def next_update(self) -> int: ...
@property
def response_status(self) -> OCSPResponseStatus: ...
@property
def certificate_status(self) -> OCSPCertStatus: ...
@property
def revocation_reason(self) -> ReasonFlags | None: ...
def serialize(self) -> bytes: ...
@staticmethod
def deserialize(src: bytes) -> OCSPResponse: ...
def authenticate_for(self, issuer_der: bytes) -> bool: ...
class OCSPRequest:
def __init__(self, peer_certificate: bytes, issuer_certificate: bytes) -> None: ...
def public_bytes(self) -> bytes: ...
class BufferReadError(ValueError): ...
class BufferWriteError(ValueError): ...
class Buffer:
def __init__(self, capacity: int = 0, data: bytes | None = None) -> None: ...
@property
def capacity(self) -> int: ...
@property
def data(self) -> bytes: ...
def data_slice(self, start: int, end: int) -> bytes: ...
def eof(self) -> bool: ...
def seek(self, pos: int) -> None: ...
def tell(self) -> int: ...
def pull_bytes(self, length: int) -> bytes: ...
def pull_uint8(self) -> int: ...
def pull_uint16(self) -> int: ...
def pull_uint24(self) -> int: ... # only for OCSP resp parsing!
def pull_uint32(self) -> int: ...
def pull_uint64(self) -> int: ...
def pull_uint_var(self) -> int: ...
def push_bytes(self, value: bytes) -> None: ...
def push_uint8(self, value: int) -> None: ...
def push_uint16(self, value: int) -> None: ...
def push_uint32(self, value: int) -> None: ...
def push_uint64(self, value: int) -> None: ...
def push_uint_var(self, value: int) -> None: ...
def encode_uint_var(value: int) -> bytes:
"""
Encode a variable-length unsigned integer.
"""
def size_uint_var(value: int) -> int:
"""
Return the number of bytes required to encode the given value
as a QUIC variable-length unsigned integer.
"""
def idna_encode(text: str) -> bytes:
"""using UTS46"""
def idna_decode(src: bytes) -> str:
"""using UTS46"""
class RangeSet(Sequence):
def __init__(self) -> None: ...
def add(self, start: int, stop: int | None = None) -> None: ...
def bounds(self) -> tuple[int, int]: ...
def shift(self) -> tuple[int, int]: ...
def subtract(self, start: int, stop: int) -> None: ...
def __contains__(self, val: object) -> bool: ...
def __eq__(self, other: Any) -> bool: ...
def __getitem__(self, key: Any) -> tuple[int, int]: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
def decode_packet_number(truncated: int, num_bits: int, expected: int) -> int:
"""
Recover a packet number from a truncated packet number.
See: Appendix A - Sample Packet Number Decoding Algorithm
"""
class QuicPacketPacer:
def __init__(self, max_datagram_size: int) -> None: ...
def next_send_time(self, now: float) -> float | None: ...
def update_after_send(self, now: float) -> None: ...
def update_bucket(self, now: float) -> None: ...
def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None: ...
class QuicRttMonitor:
"""
Roundtrip time monitor for HyStart.
"""
def add_rtt(self, rtt: float) -> None: ...
def is_rtt_increasing(self, rtt: float, now: float) -> bool: ...
class RevokedCertificate:
serial_number: str
reason: ReasonFlags
expired_at: int
class CertificateRevocationList:
def __init__(self, crl: bytes) -> None: ...
def is_revoked(self, serial_number: str) -> RevokedCertificate | None: ...
def serialize(self) -> bytes: ...
@staticmethod
def deserialize(src: bytes) -> CertificateRevocationList: ...
def __len__(self) -> int: ...
@property
def issuer(self) -> str: ...
@property
def last_updated_at(self) -> int: ...
@property
def next_update_at(self) -> int: ...
def authenticate_for(self, issuer_der: bytes) -> bool: ...
def rebuild_chain(leaf: bytes, intermediates: list[bytes]) -> list[bytes]: ...

View File

@@ -0,0 +1,3 @@
from .client import connect # noqa
from .protocol import QuicConnectionProtocol # noqa
from .server import serve # noqa

View File

@@ -0,0 +1,104 @@
from __future__ import annotations
import asyncio
import ipaddress
import socket
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Callable, cast
from ..quic.configuration import QuicConfiguration
from ..quic.connection import QuicConnection
from ..tls import SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler
__all__ = ["connect"]
# keep compatibility for Python 3.7 on Windows
if not hasattr(socket, "IPPROTO_IPV6"):
socket.IPPROTO_IPV6 = 41
@asynccontextmanager
async def connect(
host: str,
port: int,
*,
configuration: QuicConfiguration | None = None,
create_protocol: Callable | None = QuicConnectionProtocol,
session_ticket_handler: SessionTicketHandler | None = None,
stream_handler: QuicStreamHandler | None = None,
wait_connected: bool = True,
local_port: int = 0,
) -> AsyncGenerator[QuicConnectionProtocol]:
"""
Connect to a QUIC server at the given `host` and `port`.
:meth:`connect()` returns an awaitable. Awaiting it yields a
:class:`~qh3.asyncio.QuicConnectionProtocol` which can be used to
create streams.
:func:`connect` also accepts the following optional arguments:
* ``configuration`` is a :class:`~qh3.quic.configuration.QuicConfiguration`
configuration object.
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
manages the connection. It should be a callable or class accepting the same
arguments as :class:`~qh3.asyncio.QuicConnectionProtocol` and returning
an instance of :class:`~qh3.asyncio.QuicConnectionProtocol` or a subclass.
* ``session_ticket_handler`` is a callback which is invoked by the TLS
engine when a new session ticket is received.
* ``stream_handler`` is a callback which is invoked whenever a stream is
created. It must accept two arguments: a :class:`asyncio.StreamReader`
and a :class:`asyncio.StreamWriter`.
* ``local_port`` is the UDP port number that this client wants to bind.
"""
loop = asyncio.get_running_loop()
local_host = "::"
# if host is not an IP address, pass it to enable SNI
try:
ipaddress.ip_address(host)
server_name = None
except ValueError:
server_name = host
# lookup remote address
infos = await loop.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
addr = infos[0][4]
if len(addr) == 2:
addr = ("::ffff:" + addr[0], addr[1], 0, 0)
# prepare QUIC connection
if configuration is None:
configuration = QuicConfiguration(is_client=True)
if configuration.server_name is None:
configuration.server_name = server_name
connection = QuicConnection(
configuration=configuration, session_ticket_handler=session_ticket_handler
)
# explicitly enable IPv4/IPv6 dual stack
sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
completed = False
try:
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
sock.bind((local_host, local_port, 0, 0))
completed = True
finally:
if not completed:
sock.close()
# connect
transport, protocol = await loop.create_datagram_endpoint(
lambda: create_protocol(connection, stream_handler=stream_handler),
sock=sock,
)
protocol = cast(QuicConnectionProtocol, protocol)
try:
protocol.connect(addr)
if wait_connected:
await protocol.wait_connected()
yield protocol
finally:
protocol.close()
await protocol.wait_closed()
transport.close()

View File

@@ -0,0 +1,254 @@
from __future__ import annotations
import asyncio
from typing import Any, Callable, cast
from ..quic import events
from ..quic.connection import NetworkAddress, QuicConnection
QuicConnectionIdHandler = Callable[[bytes], None]
QuicStreamHandler = Callable[[asyncio.StreamReader, asyncio.StreamWriter], None]
class QuicConnectionProtocol(asyncio.DatagramProtocol):
def __init__(
self, quic: QuicConnection, stream_handler: QuicStreamHandler | None = None
):
loop = asyncio.get_running_loop()
self._closed = asyncio.Event()
self._connected = False
self._connected_waiter: asyncio.Future[None] | None = None
self._loop = loop
self._ping_waiters: dict[int, asyncio.Future[None]] = {}
self._quic = quic
self._stream_readers: dict[int, asyncio.StreamReader] = {}
self._timer: asyncio.TimerHandle | None = None
self._timer_at: float | None = None
self._transmit_task: asyncio.Handle | None = None
self._transport: asyncio.DatagramTransport | None = None
# callbacks
self._connection_id_issued_handler: QuicConnectionIdHandler = lambda c: None
self._connection_id_retired_handler: QuicConnectionIdHandler = lambda c: None
self._connection_terminated_handler: Callable[[], None] = lambda: None
if stream_handler is not None:
self._stream_handler = stream_handler
else:
self._stream_handler = lambda r, w: None
def change_connection_id(self) -> None:
"""
Change the connection ID used to communicate with the peer.
The previous connection ID will be retired.
"""
self._quic.change_connection_id()
self.transmit()
def close(self) -> None:
"""
Close the connection.
"""
self._quic.close()
self.transmit()
def connect(self, addr: NetworkAddress) -> None:
"""
Initiate the TLS handshake.
This method can only be called for clients and a single time.
"""
self._quic.connect(addr, now=self._loop.time())
self.transmit()
async def create_stream(
self, is_unidirectional: bool = False
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""
Create a QUIC stream and return a pair of (reader, writer) objects.
The returned reader and writer objects are instances of
:class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter` classes.
"""
stream_id = self._quic.get_next_available_stream_id(
is_unidirectional=is_unidirectional
)
return self._create_stream(stream_id)
def request_key_update(self) -> None:
"""
Request an update of the encryption keys.
"""
self._quic.request_key_update()
self.transmit()
async def ping(self) -> None:
"""
Ping the peer and wait for the response.
"""
waiter = self._loop.create_future()
uid = id(waiter)
self._ping_waiters[uid] = waiter
self._quic.send_ping(uid)
self.transmit()
await asyncio.shield(waiter)
def transmit(self) -> None:
"""
Send pending datagrams to the peer and arm the timer if needed.
"""
self._transmit_task = None
# send datagrams
for data, addr in self._quic.datagrams_to_send(now=self._loop.time()):
self._transport.sendto(data, addr)
# re-arm timer
timer_at = self._quic.get_timer()
if self._timer is not None and self._timer_at != timer_at:
self._timer.cancel()
self._timer = None
if self._timer is None and timer_at is not None:
self._timer = self._loop.call_at(timer_at, self._handle_timer)
self._timer_at = timer_at
async def wait_closed(self) -> None:
"""
Wait for the connection to be closed.
"""
await self._closed.wait()
async def wait_connected(self) -> None:
"""
Wait for the TLS handshake to complete.
"""
assert self._connected_waiter is None, "already awaiting connected"
if not self._connected:
self._connected_waiter = self._loop.create_future()
await asyncio.shield(self._connected_waiter)
# asyncio.Transport
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: bytes | str, addr: NetworkAddress) -> None:
self._quic.receive_datagram(cast(bytes, data), addr, now=self._loop.time())
self._process_events()
self.transmit()
# overridable
def quic_event_received(self, event: events.QuicEvent) -> None:
"""
Called when a QUIC event is received.
Reimplement this in your subclass to handle the events.
"""
# FIXME: move this to a subclass
if isinstance(event, events.ConnectionTerminated):
for reader in self._stream_readers.values():
reader.feed_eof()
elif isinstance(event, events.StreamDataReceived):
reader = self._stream_readers.get(event.stream_id, None)
if reader is None:
reader, writer = self._create_stream(event.stream_id)
self._stream_handler(reader, writer)
reader.feed_data(event.data)
if event.end_stream:
reader.feed_eof()
# private
def _create_stream(
self, stream_id: int
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
adapter = QuicStreamAdapter(self, stream_id)
reader = asyncio.StreamReader()
writer = asyncio.StreamWriter(adapter, None, reader, self._loop)
self._stream_readers[stream_id] = reader
return reader, writer
def _handle_timer(self) -> None:
now = max(self._timer_at, self._loop.time())
self._timer = None
self._timer_at = None
self._quic.handle_timer(now=now)
self._process_events()
self.transmit()
def _process_events(self) -> None:
event = self._quic.next_event()
while event is not None:
if isinstance(event, events.ConnectionIdIssued):
self._connection_id_issued_handler(event.connection_id)
elif isinstance(event, events.ConnectionIdRetired):
self._connection_id_retired_handler(event.connection_id)
elif isinstance(event, events.ConnectionTerminated):
self._connection_terminated_handler()
# abort connection waiter
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected_waiter = None
waiter.set_exception(ConnectionError)
# abort ping waiters
for waiter in self._ping_waiters.values():
waiter.set_exception(ConnectionError)
self._ping_waiters.clear()
self._closed.set()
elif isinstance(event, events.HandshakeCompleted):
if self._connected_waiter is not None:
waiter = self._connected_waiter
self._connected = True
self._connected_waiter = None
waiter.set_result(None)
elif isinstance(event, events.PingAcknowledged):
waiter = self._ping_waiters.pop(event.uid, None)
if waiter is not None:
waiter.set_result(None)
self.quic_event_received(event)
event = self._quic.next_event()
def _transmit_soon(self) -> None:
if self._transmit_task is None:
self._transmit_task = self._loop.call_soon(self.transmit)
class QuicStreamAdapter(asyncio.Transport):
def __init__(self, protocol: QuicConnectionProtocol, stream_id: int):
super().__init__()
self.protocol = protocol
self.stream_id = stream_id
self._closing = False
def can_write_eof(self) -> bool:
return True
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""
Get information about the underlying QUIC stream.
"""
if name == "stream_id":
return self.stream_id
def write(self, data) -> None:
self.protocol._quic.send_stream_data(self.stream_id, data)
self.protocol._transmit_soon()
def write_eof(self) -> None:
if self._closing:
return
self._closing = True
self.protocol._quic.send_stream_data(self.stream_id, b"", end_stream=True)
self.protocol._transmit_soon()
def close(self) -> None:
self.write_eof()
def is_closing(self) -> bool:
return self._closing

View File

@@ -0,0 +1,217 @@
from __future__ import annotations
import asyncio
import os
from functools import partial
from typing import Callable, cast
from .._hazmat import Buffer
from ..quic.configuration import QuicConfiguration
from ..quic.connection import NetworkAddress, QuicConnection
from ..quic.packet import (
QuicPacketType,
encode_quic_retry,
encode_quic_version_negotiation,
pull_quic_header,
)
from ..quic.retry import QuicRetryTokenHandler
from ..tls import SessionTicketFetcher, SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler
__all__ = ["serve"]
class QuicServer(asyncio.DatagramProtocol):
def __init__(
self,
*,
configuration: QuicConfiguration,
create_protocol: Callable = QuicConnectionProtocol,
session_ticket_fetcher: SessionTicketFetcher | None = None,
session_ticket_handler: SessionTicketHandler | None = None,
retry: bool = False,
stream_handler: QuicStreamHandler | None = None,
) -> None:
self._configuration = configuration
self._create_protocol = create_protocol
self._loop = asyncio.get_running_loop()
self._protocols: dict[bytes, QuicConnectionProtocol] = {}
self._session_ticket_fetcher = session_ticket_fetcher
self._session_ticket_handler = session_ticket_handler
self._transport: asyncio.DatagramTransport | None = None
self._stream_handler = stream_handler
if retry:
self._retry = QuicRetryTokenHandler()
else:
self._retry = None
def close(self):
for protocol in set(self._protocols.values()):
protocol.close()
self._protocols.clear()
self._transport.close()
def connection_made(self, transport: asyncio.BaseTransport) -> None:
self._transport = cast(asyncio.DatagramTransport, transport)
def datagram_received(self, data: bytes | str, addr: NetworkAddress) -> None:
data = cast(bytes, data)
buf = Buffer(data=data)
try:
header = pull_quic_header(
buf, host_cid_length=self._configuration.connection_id_length
)
except ValueError:
return
# version negotiation
if (
header.version is not None
and header.version not in self._configuration.supported_versions
):
self._transport.sendto(
encode_quic_version_negotiation(
source_cid=header.destination_cid,
destination_cid=header.source_cid,
supported_versions=self._configuration.supported_versions,
),
addr,
)
return
protocol = self._protocols.get(header.destination_cid, None)
original_destination_connection_id: bytes | None = None
retry_source_connection_id: bytes | None = None
if (
protocol is None
and len(data) >= 1200
and header.packet_type == QuicPacketType.INITIAL
):
# retry
if self._retry is not None:
if not header.token:
# create a retry token
source_cid = os.urandom(8)
self._transport.sendto(
encode_quic_retry(
version=header.version,
source_cid=source_cid,
destination_cid=header.source_cid,
original_destination_cid=header.destination_cid,
retry_token=self._retry.create_token(
addr, header.destination_cid, source_cid
),
),
addr,
)
return
else:
# validate retry token
try:
(
original_destination_connection_id,
retry_source_connection_id,
) = self._retry.validate_token(addr, header.token)
except ValueError:
return
else:
original_destination_connection_id = header.destination_cid
# create new connection
connection = QuicConnection(
configuration=self._configuration,
original_destination_connection_id=original_destination_connection_id,
retry_source_connection_id=retry_source_connection_id,
session_ticket_fetcher=self._session_ticket_fetcher,
session_ticket_handler=self._session_ticket_handler,
)
protocol = self._create_protocol(
connection, stream_handler=self._stream_handler
)
protocol.connection_made(self._transport)
# register callbacks
protocol._connection_id_issued_handler = partial(
self._connection_id_issued, protocol=protocol
)
protocol._connection_id_retired_handler = partial(
self._connection_id_retired, protocol=protocol
)
protocol._connection_terminated_handler = partial(
self._connection_terminated, protocol=protocol
)
self._protocols[header.destination_cid] = protocol
self._protocols[connection.host_cid] = protocol
if protocol is not None:
protocol.datagram_received(data, addr)
def _connection_id_issued(self, cid: bytes, protocol: QuicConnectionProtocol):
self._protocols[cid] = protocol
def _connection_id_retired(
self, cid: bytes, protocol: QuicConnectionProtocol
) -> None:
assert self._protocols[cid] == protocol
del self._protocols[cid]
def _connection_terminated(self, protocol: QuicConnectionProtocol):
for cid, proto in list(self._protocols.items()):
if proto == protocol:
del self._protocols[cid]
async def serve(
host: str,
port: int,
*,
configuration: QuicConfiguration,
create_protocol: Callable = QuicConnectionProtocol,
session_ticket_fetcher: SessionTicketFetcher | None = None,
session_ticket_handler: SessionTicketHandler | None = None,
retry: bool = False,
stream_handler: QuicStreamHandler = None,
) -> QuicServer:
"""
Start a QUIC server at the given `host` and `port`.
:func:`serve` requires a :class:`~qh3.quic.configuration.QuicConfiguration`
containing TLS certificate and private key as the ``configuration`` argument.
:func:`serve` also accepts the following optional arguments:
* ``create_protocol`` allows customizing the :class:`~asyncio.Protocol` that
manages the connection. It should be a callable or class accepting the same
arguments as :class:`~qh3.asyncio.QuicConnectionProtocol` and returning
an instance of :class:`~qh3.asyncio.QuicConnectionProtocol` or a subclass.
* ``session_ticket_fetcher`` is a callback which is invoked by the TLS
engine when a session ticket is presented by the peer. It should return
the session ticket with the specified ID or `None` if it is not found.
* ``session_ticket_handler`` is a callback which is invoked by the TLS
engine when a new session ticket is issued. It should store the session
ticket for future lookup.
* ``retry`` specifies whether client addresses should be validated prior to
the cryptographic handshake using a retry packet.
* ``stream_handler`` is a callback which is invoked whenever a stream is
created. It must accept two arguments: a :class:`asyncio.StreamReader`
and a :class:`asyncio.StreamWriter`.
"""
loop = asyncio.get_running_loop()
_, protocol = await loop.create_datagram_endpoint(
lambda: QuicServer(
configuration=configuration,
create_protocol=create_protocol,
session_ticket_fetcher=session_ticket_fetcher,
session_ticket_handler=session_ticket_handler,
retry=retry,
stream_handler=stream_handler,
),
local_addr=(host, port),
)
return protocol

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,116 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Tuple
Headers = List[Tuple[bytes, bytes]]
class H3Event:
"""
Base class for HTTP/3 events.
"""
@dataclass
class DataReceived(H3Event):
"""
The DataReceived event is fired whenever data is received on a stream from
the remote peer.
"""
data: bytes
"The data which was received."
stream_id: int
"The ID of the stream the data was received for."
stream_ended: bool
"Whether the STREAM frame had the FIN bit set."
push_id: int | None = None
"The Push ID or `None` if this is not a push."
@dataclass
class DatagramReceived(H3Event):
"""
The DatagramReceived is fired whenever a datagram is received from the
the remote peer.
"""
data: bytes
"The data which was received."
flow_id: int
"The ID of the flow the data was received for."
@dataclass
class InformationalHeadersReceived(H3Event):
"""
This event is fired whenever an informational response has been caught inflight!
The stream cannot be ended there.
"""
headers: Headers
"The headers."
stream_id: int
"The ID of the stream the headers were received for."
@dataclass
class HeadersReceived(H3Event):
"""
The HeadersReceived event is fired whenever headers are received.
"""
headers: Headers
"The headers."
stream_id: int
"The ID of the stream the headers were received for."
stream_ended: bool
"Whether the STREAM frame had the FIN bit set."
push_id: int | None = None
"The Push ID or `None` if this is not a push."
@dataclass
class PushPromiseReceived(H3Event):
"""
The PushedStreamReceived event is fired whenever a pushed stream has been
received from the remote peer.
"""
headers: Headers
"The request headers."
push_id: int
"The Push ID of the push promise."
stream_id: int
"The Stream ID of the stream that the push is related to."
@dataclass
class WebTransportStreamDataReceived(H3Event):
"""
The WebTransportStreamDataReceived is fired whenever data is received
for a WebTransport stream.
"""
data: bytes
"The data which was received."
stream_id: int
"The ID of the stream the data was received for."
stream_ended: bool
"Whether the STREAM frame had the FIN bit set."
session_id: int
"The ID of the session the data was received for."

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
class H3Error(Exception):
"""
Base class for HTTP/3 exceptions.
"""
class NoAvailablePushIDError(H3Error):
"""
There are no available push IDs left.
"""

View File

@@ -0,0 +1 @@
Marker

View File

@@ -0,0 +1,191 @@
from __future__ import annotations
from dataclasses import dataclass, field
from os import PathLike
from re import split
from typing import TYPE_CHECKING, TextIO
if TYPE_CHECKING:
from .._hazmat import Certificate as X509Certificate
from .._hazmat import DsaPrivateKey, EcPrivateKey, Ed25519PrivateKey, RsaPrivateKey
from ..tls import (
CipherSuite,
SessionTicket,
load_pem_private_key,
load_pem_x509_certificates,
)
from .logger import QuicLogger
from .packet import QuicProtocolVersion
from .packet_builder import PACKET_MAX_SIZE
@dataclass
class QuicConfiguration:
"""
A QUIC configuration.
"""
alpn_protocols: list[str] | None = None
"""
A list of supported ALPN protocols.
"""
connection_id_length: int = 8
"""
The length in bytes of local connection IDs.
"""
idle_timeout: float = 60.0
"""
The idle timeout in seconds.
The connection is terminated if nothing is received for the given duration.
"""
is_client: bool = True
"""
Whether this is the client side of the QUIC connection.
"""
max_data: int = 1048576
"""
Connection-wide flow control limit.
"""
max_datagram_size: int = PACKET_MAX_SIZE
"""
The maximum QUIC payload size in bytes to send, excluding UDP or IP overhead.
"""
probe_datagram_size: bool = True
"""
Enable path MTU discovery. Client-only.
"""
max_stream_data: int = 1048576
"""
Per-stream flow control limit.
"""
quic_logger: QuicLogger | None = None
"""
The :class:`~qh3.quic.logger.QuicLogger` instance to log events to.
"""
secrets_log_file: TextIO = None
"""
A file-like object in which to log traffic secrets.
This is useful to analyze traffic captures with Wireshark.
"""
server_name: str | None = None
"""
The server name to send during the TLS handshake the Server Name Indication.
.. note:: This is only used by clients.
"""
session_ticket: SessionTicket | None = None
"""
The TLS session ticket which should be used for session resumption.
"""
hostname_checks_common_name: bool = False
assert_fingerprint: str | None = None
verify_hostname: bool = True
cadata: bytes | None = None
cafile: str | None = None
capath: str | None = None
certificate: X509Certificate | None = None
certificate_chain: list[X509Certificate] = field(default_factory=list)
cipher_suites: list[CipherSuite] | None = None
initial_rtt: float = 0.1
max_datagram_frame_size: int | None = None
original_version: int | None = None
private_key: (
EcPrivateKey | Ed25519PrivateKey | DsaPrivateKey | RsaPrivateKey | None
) = None
quantum_readiness_test: bool = False
supported_versions: list[int] = field(
default_factory=lambda: [
QuicProtocolVersion.VERSION_1,
QuicProtocolVersion.VERSION_2,
]
)
verify_mode: int | None = None
def load_cert_chain(
self,
certfile: str | bytes | PathLike,
keyfile: str | bytes | PathLike | None = None,
password: bytes | str | None = None,
) -> None:
"""
Load a private key and the corresponding certificate.
"""
if isinstance(certfile, str):
certfile = certfile.encode()
elif isinstance(certfile, PathLike):
certfile = str(certfile).encode()
if keyfile is not None:
if isinstance(keyfile, str):
keyfile = keyfile.encode()
elif isinstance(keyfile, PathLike):
keyfile = str(keyfile).encode()
# we either have the certificate or a file path in certfile/keyfile.
if b"-----BEGIN" not in certfile:
with open(certfile, "rb") as fp:
certfile = fp.read()
if keyfile is not None:
with open(keyfile, "rb") as fp:
keyfile = fp.read()
is_crlf = b"-----\r\n" in certfile
boundary = (
b"-----BEGIN PRIVATE KEY-----\n"
if not is_crlf
else b"-----BEGIN PRIVATE KEY-----\r\n"
)
chunks = split(b"\n" + boundary, certfile)
certificates = load_pem_x509_certificates(chunks[0])
if len(chunks) == 2:
private_key = boundary + chunks[1]
self.private_key = load_pem_private_key(private_key)
self.certificate = certificates[0]
self.certificate_chain = certificates[1:]
if keyfile is not None:
self.private_key = load_pem_private_key(
keyfile,
password=(
password.encode("utf8") if isinstance(password, str) else password
),
)
def load_verify_locations(
self,
cafile: str | None = None,
capath: str | None = None,
cadata: bytes | None = None,
) -> None:
"""
Load a set of "certification authority" (CA) certificates used to
validate other peers' certificates.
"""
self.cafile = cafile
self.capath = capath
self.cadata = cadata

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,284 @@
from __future__ import annotations
import binascii
from typing import Callable
from .._hazmat import (
AeadAes128Gcm,
AeadAes256Gcm,
AeadChaCha20Poly1305,
CryptoError,
decode_packet_number,
)
from .._hazmat import QUICHeaderProtection as HeaderProtection
from ..tls import CipherSuite, cipher_suite_hash, hkdf_expand_label, hkdf_extract
from .packet import (
QuicProtocolVersion,
is_long_header,
)
CIPHER_SUITES = {
CipherSuite.AES_128_GCM_SHA256: (b"aes-128-ecb", b"aes-128-gcm"),
CipherSuite.CHACHA20_POLY1305_SHA256: (b"chacha20", b"chacha20-poly1305"),
CipherSuite.AES_256_GCM_SHA384: (b"aes-256-ecb", b"aes-256-gcm"),
}
INITIAL_CIPHER_SUITE = CipherSuite.AES_128_GCM_SHA256
INITIAL_SALT_VERSION_1 = binascii.unhexlify("38762cf7f55934b34d179ae6a4c80cadccbb7f0a")
INITIAL_SALT_VERSION_2 = binascii.unhexlify("0dede3def700a6db819381be6e269dcbf9bd2ed9")
SAMPLE_SIZE = 16
Callback = Callable[[str], None]
def NoCallback(trigger: str) -> None:
pass
class KeyUnavailableError(CryptoError):
pass
def derive_key_iv_hp(
*, cipher_suite: CipherSuite, secret: bytes, version: int
) -> tuple[bytes, bytes, bytes]:
algorithm = cipher_suite_hash(cipher_suite)
if cipher_suite in [
CipherSuite.AES_256_GCM_SHA384,
CipherSuite.CHACHA20_POLY1305_SHA256,
]:
key_size = 32
else:
key_size = 16
if version == QuicProtocolVersion.VERSION_2:
return (
hkdf_expand_label(algorithm, secret, b"quicv2 key", b"", key_size),
hkdf_expand_label(algorithm, secret, b"quicv2 iv", b"", 12),
hkdf_expand_label(algorithm, secret, b"quicv2 hp", b"", key_size),
)
else:
return (
hkdf_expand_label(algorithm, secret, b"quic key", b"", key_size),
hkdf_expand_label(algorithm, secret, b"quic iv", b"", 12),
hkdf_expand_label(algorithm, secret, b"quic hp", b"", key_size),
)
class CryptoContext:
__slots__ = (
"aead",
"cipher_suite",
"hp",
"key_phase",
"secret",
"version",
"_setup_cb",
"_teardown_cb",
)
def __init__(
self,
key_phase: int = 0,
setup_cb: Callback = NoCallback,
teardown_cb: Callback = NoCallback,
) -> None:
self.aead: AeadChaCha20Poly1305 | AeadAes128Gcm | AeadAes256Gcm | None = None
self.cipher_suite: CipherSuite | None = None
self.hp: HeaderProtection | None = None
self.key_phase = key_phase
self.secret: bytes | None = None
self.version: int | None = None
self._setup_cb = setup_cb
self._teardown_cb = teardown_cb
def decrypt_packet(
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
) -> tuple[bytes, bytes, int, bool]:
if self.aead is None:
raise KeyUnavailableError("Decryption key is not available")
# header protection
plain_header, packet_number = self.hp.remove(packet, encrypted_offset)
first_byte = plain_header[0]
# packet number
pn_length = (first_byte & 0x03) + 1
packet_number = decode_packet_number(
packet_number, pn_length * 8, expected_packet_number
)
# detect key phase change
crypto = self
if not is_long_header(first_byte):
key_phase = (first_byte & 4) >> 2
if key_phase != self.key_phase:
crypto = next_key_phase(self)
# payload protection
payload = crypto.aead.decrypt(
packet_number, packet[len(plain_header) :], plain_header
)
return plain_header, payload, packet_number, crypto != self
def encrypt_packet(
self, plain_header: bytes, plain_payload: bytes, packet_number: int
) -> bytes:
assert self.is_valid(), "Encryption key is not available"
# payload protection
protected_payload = self.aead.encrypt(
packet_number, plain_payload, plain_header
)
# header protection
return self.hp.apply(plain_header, protected_payload)
def is_valid(self) -> bool:
return self.aead is not None
def setup(self, *, cipher_suite: CipherSuite, secret: bytes, version: int) -> None:
hp_cipher_name, aead_cipher_name = CIPHER_SUITES[cipher_suite]
key, iv, hp = derive_key_iv_hp(
cipher_suite=cipher_suite,
secret=secret,
version=version,
)
if aead_cipher_name == b"chacha20-poly1305":
self.aead = AeadChaCha20Poly1305(key, iv)
elif aead_cipher_name == b"aes-256-gcm":
self.aead = AeadAes256Gcm(key, iv)
elif aead_cipher_name == b"aes-128-gcm":
self.aead = AeadAes128Gcm(key, iv)
else:
raise CryptoError(f"Invalid cipher name: {aead_cipher_name.decode()}")
self.cipher_suite = cipher_suite
self.hp = HeaderProtection(hp_cipher_name.decode(), hp)
self.secret = secret
self.version = version
# trigger callback
self._setup_cb("tls")
def teardown(self) -> None:
self.aead = None
self.cipher_suite = None
self.hp = None
self.secret = None
# trigger callback
self._teardown_cb("tls")
def apply_key_phase(self: CryptoContext, crypto: CryptoContext, trigger: str) -> None:
self.aead = crypto.aead
self.key_phase = crypto.key_phase
self.secret = crypto.secret
# trigger callback
self._setup_cb(trigger)
def next_key_phase(self: CryptoContext) -> CryptoContext:
algorithm = cipher_suite_hash(self.cipher_suite)
crypto = CryptoContext(key_phase=int(not self.key_phase))
crypto.setup(
cipher_suite=self.cipher_suite,
secret=hkdf_expand_label(
algorithm, self.secret, b"quic ku", b"", int(algorithm / 8)
),
version=self.version,
)
return crypto
class CryptoPair:
__slots__ = (
"aead_tag_size",
"recv",
"send",
"_update_key_requested",
)
def __init__(
self,
recv_setup_cb: Callback = NoCallback,
recv_teardown_cb: Callback = NoCallback,
send_setup_cb: Callback = NoCallback,
send_teardown_cb: Callback = NoCallback,
) -> None:
self.aead_tag_size = 16
self.recv = CryptoContext(setup_cb=recv_setup_cb, teardown_cb=recv_teardown_cb)
self.send = CryptoContext(setup_cb=send_setup_cb, teardown_cb=send_teardown_cb)
self._update_key_requested = False
def decrypt_packet(
self, packet: bytes, encrypted_offset: int, expected_packet_number: int
) -> tuple[bytes, bytes, int]:
plain_header, payload, packet_number, update_key = self.recv.decrypt_packet(
packet, encrypted_offset, expected_packet_number
)
if update_key:
self._update_key("remote_update")
return plain_header, payload, packet_number
def encrypt_packet(
self, plain_header: bytes, plain_payload: bytes, packet_number: int
) -> bytes:
if self._update_key_requested:
self._update_key("local_update")
return self.send.encrypt_packet(plain_header, plain_payload, packet_number)
def setup_initial(self, cid: bytes, is_client: bool, version: int) -> None:
if is_client:
recv_label, send_label = b"server in", b"client in"
else:
recv_label, send_label = b"client in", b"server in"
if version == QuicProtocolVersion.VERSION_2:
initial_salt = INITIAL_SALT_VERSION_2
else:
initial_salt = INITIAL_SALT_VERSION_1
algorithm = cipher_suite_hash(INITIAL_CIPHER_SUITE)
digest_size = int(algorithm / 8)
initial_secret = hkdf_extract(algorithm, initial_salt, cid)
self.recv.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=hkdf_expand_label(
algorithm, initial_secret, recv_label, b"", digest_size
),
version=version,
)
self.send.setup(
cipher_suite=INITIAL_CIPHER_SUITE,
secret=hkdf_expand_label(
algorithm, initial_secret, send_label, b"", digest_size
),
version=version,
)
def teardown(self) -> None:
self.recv.teardown()
self.send.teardown()
def update_key(self) -> None:
self._update_key_requested = True
@property
def key_phase(self) -> int:
if self._update_key_requested:
return int(not self.recv.key_phase)
else:
return self.recv.key_phase
def _update_key(self, trigger: str) -> None:
apply_key_phase(self.recv, next_key_phase(self.recv), trigger=trigger)
apply_key_phase(self.send, next_key_phase(self.send), trigger=trigger)
self._update_key_requested = False

View File

@@ -0,0 +1,127 @@
from __future__ import annotations
from dataclasses import dataclass
class QuicEvent:
"""
Base class for QUIC events.
"""
pass
@dataclass
class ConnectionIdIssued(QuicEvent):
connection_id: bytes
@dataclass
class ConnectionIdRetired(QuicEvent):
connection_id: bytes
@dataclass
class ConnectionTerminated(QuicEvent):
"""
The ConnectionTerminated event is fired when the QUIC connection is terminated.
"""
error_code: int
"The error code which was specified when closing the connection."
frame_type: int | None
"The frame type which caused the connection to be closed, or `None`."
reason_phrase: str
"The human-readable reason for which the connection was closed."
@dataclass
class DatagramFrameReceived(QuicEvent):
"""
The DatagramFrameReceived event is fired when a DATAGRAM frame is received.
"""
data: bytes
"The data which was received."
@dataclass
class HandshakeCompleted(QuicEvent):
"""
The HandshakeCompleted event is fired when the TLS handshake completes.
"""
alpn_protocol: str | None
"The protocol which was negotiated using ALPN, or `None`."
early_data_accepted: bool
"Whether early (0-RTT) data was accepted by the remote peer."
session_resumed: bool
"Whether a TLS session was resumed."
@dataclass
class PingAcknowledged(QuicEvent):
"""
The PingAcknowledged event is fired when a PING frame is acknowledged.
"""
uid: int
"The unique ID of the PING."
@dataclass
class ProtocolNegotiated(QuicEvent):
"""
The ProtocolNegotiated event is fired when ALPN negotiation completes.
"""
alpn_protocol: str | None
"The protocol which was negotiated using ALPN, or `None`."
@dataclass
class StreamDataReceived(QuicEvent):
"""
The StreamDataReceived event is fired whenever data is received on a
stream.
"""
data: bytes
"The data which was received."
end_stream: bool
"Whether the STREAM frame had the FIN bit set."
stream_id: int
"The ID of the stream the data was received for."
@dataclass
class StopSendingReceived(QuicEvent):
"""
The StopSendingReceived event is fired when the remote peer requests
stopping data transmission on a stream.
"""
error_code: int
"The error code that was sent from the peer."
stream_id: int
"The ID of the stream that the peer requested stopping data transmission."
@dataclass
class StreamReset(QuicEvent):
"""
The StreamReset event is fired when the remote peer resets a stream.
"""
error_code: int
"The error code that triggered the reset."
stream_id: int
"The ID of the stream that was reset."

View File

@@ -0,0 +1,333 @@
from __future__ import annotations
import binascii
import json
import os
import time
from collections import deque
from typing import Any
from .._hazmat import RangeSet
from ..h3.events import Headers
from .packet import (
QuicFrameType,
QuicPacketType,
QuicStreamFrame,
QuicTransportParameters,
)
PACKET_TYPE_NAMES = {
QuicPacketType.INITIAL: "initial",
QuicPacketType.HANDSHAKE: "handshake",
QuicPacketType.ZERO_RTT: "0RTT",
QuicPacketType.ONE_RTT: "1RTT",
QuicPacketType.RETRY: "retry",
QuicPacketType.VERSION_NEGOTIATION: "version_negotiation",
}
QLOG_VERSION = "0.3"
def hexdump(data: bytes) -> str:
return binascii.hexlify(data).decode("ascii")
class QuicLoggerTrace:
"""
A QUIC event trace.
Events are logged in the format defined by qlog.
See:
- https://datatracker.ietf.org/doc/html/draft-ietf-quic-qlog-main-schema-02
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-quic-events
- https://datatracker.ietf.org/doc/html/draft-marx-quic-qlog-h3-events
"""
def __init__(self, *, is_client: bool, odcid: bytes) -> None:
self._odcid = odcid
self._events: deque[dict[str, Any]] = deque()
self._vantage_point = {
"name": "qh3",
"type": "client" if is_client else "server",
}
# QUIC
def encode_ack_frame(self, ranges: RangeSet, delay: float) -> dict:
return {
"ack_delay": self.encode_time(delay),
"acked_ranges": [[x[0], x[1] - 1] for x in ranges],
"frame_type": "ack",
}
def encode_connection_close_frame(
self, error_code: int, frame_type: int | None, reason_phrase: str
) -> dict:
attrs = {
"error_code": error_code,
"error_space": "application" if frame_type is None else "transport",
"frame_type": "connection_close",
"raw_error_code": error_code,
"reason": reason_phrase,
}
if frame_type is not None:
attrs["trigger_frame_type"] = frame_type
return attrs
def encode_connection_limit_frame(self, frame_type: int, maximum: int) -> dict:
if frame_type == QuicFrameType.MAX_DATA:
return {"frame_type": "max_data", "maximum": maximum}
else:
return {
"frame_type": "max_streams",
"maximum": maximum,
"stream_type": (
"unidirectional"
if frame_type == QuicFrameType.MAX_STREAMS_UNI
else "bidirectional"
),
}
def encode_crypto_frame(self, frame: QuicStreamFrame) -> dict:
return {
"frame_type": "crypto",
"length": len(frame.data),
"offset": frame.offset,
}
def encode_data_blocked_frame(self, limit: int) -> dict:
return {"frame_type": "data_blocked", "limit": limit}
def encode_datagram_frame(self, length: int) -> dict:
return {"frame_type": "datagram", "length": length}
def encode_handshake_done_frame(self) -> dict:
return {"frame_type": "handshake_done"}
def encode_max_stream_data_frame(self, maximum: int, stream_id: int) -> dict:
return {
"frame_type": "max_stream_data",
"maximum": maximum,
"stream_id": stream_id,
}
def encode_new_connection_id_frame(
self,
connection_id: bytes,
retire_prior_to: int,
sequence_number: int,
stateless_reset_token: bytes,
) -> dict:
return {
"connection_id": hexdump(connection_id),
"frame_type": "new_connection_id",
"length": len(connection_id),
"reset_token": hexdump(stateless_reset_token),
"retire_prior_to": retire_prior_to,
"sequence_number": sequence_number,
}
def encode_new_token_frame(self, token: bytes) -> dict:
return {
"frame_type": "new_token",
"length": len(token),
"token": hexdump(token),
}
def encode_padding_frame(self) -> dict:
return {"frame_type": "padding"}
def encode_path_challenge_frame(self, data: bytes) -> dict:
return {"data": hexdump(data), "frame_type": "path_challenge"}
def encode_path_response_frame(self, data: bytes) -> dict:
return {"data": hexdump(data), "frame_type": "path_response"}
def encode_ping_frame(self) -> dict:
return {"frame_type": "ping"}
def encode_reset_stream_frame(
self, error_code: int, final_size: int, stream_id: int
) -> dict:
return {
"error_code": error_code,
"final_size": final_size,
"frame_type": "reset_stream",
"stream_id": stream_id,
}
def encode_retire_connection_id_frame(self, sequence_number: int) -> dict:
return {
"frame_type": "retire_connection_id",
"sequence_number": sequence_number,
}
def encode_stream_data_blocked_frame(self, limit: int, stream_id: int) -> dict:
return {
"frame_type": "stream_data_blocked",
"limit": limit,
"stream_id": stream_id,
}
def encode_stop_sending_frame(self, error_code: int, stream_id: int) -> dict:
return {
"frame_type": "stop_sending",
"error_code": error_code,
"stream_id": stream_id,
}
def encode_stream_frame(self, frame: QuicStreamFrame, stream_id: int) -> dict:
return {
"fin": frame.fin,
"frame_type": "stream",
"length": len(frame.data),
"offset": frame.offset,
"stream_id": stream_id,
}
def encode_streams_blocked_frame(self, is_unidirectional: bool, limit: int) -> dict:
return {
"frame_type": "streams_blocked",
"limit": limit,
"stream_type": "unidirectional" if is_unidirectional else "bidirectional",
}
def encode_time(self, seconds: float) -> float:
"""
Convert a time to milliseconds.
"""
return seconds * 1000
def encode_transport_parameters(
self, owner: str, parameters: QuicTransportParameters
) -> dict[str, Any]:
data: dict[str, Any] = {"owner": owner}
for param_name, param_value in parameters.__dict__.items():
if isinstance(param_value, bool):
data[param_name] = param_value
elif isinstance(param_value, bytes):
data[param_name] = hexdump(param_value)
elif isinstance(param_value, int):
data[param_name] = param_value
return data
def packet_type(self, packet_type: QuicPacketType) -> str:
return PACKET_TYPE_NAMES[packet_type]
# HTTP/3
def encode_http3_data_frame(self, length: int, stream_id: int) -> dict:
return {
"frame": {"frame_type": "data"},
"length": length,
"stream_id": stream_id,
}
def encode_http3_headers_frame(
self, length: int, headers: Headers, stream_id: int
) -> dict:
return {
"frame": {
"frame_type": "headers",
"headers": self._encode_http3_headers(headers),
},
"length": length,
"stream_id": stream_id,
}
def encode_http3_push_promise_frame(
self, length: int, headers: Headers, push_id: int, stream_id: int
) -> dict:
return {
"frame": {
"frame_type": "push_promise",
"headers": self._encode_http3_headers(headers),
"push_id": push_id,
},
"length": length,
"stream_id": stream_id,
}
def _encode_http3_headers(self, headers: Headers) -> list[dict]:
return [
{"name": h[0].decode("utf8"), "value": h[1].decode("utf8")} for h in headers
]
# CORE
def log_event(self, *, category: str, event: str, data: dict) -> None:
self._events.append(
{
"data": data,
"name": category + ":" + event,
"time": self.encode_time(time.time()),
}
)
def to_dict(self) -> dict[str, Any]:
"""
Return the trace as a dictionary which can be written as JSON.
"""
return {
"common_fields": {
"ODCID": hexdump(self._odcid),
},
"events": list(self._events),
"vantage_point": self._vantage_point,
}
class QuicLogger:
"""
A QUIC event logger which stores traces in memory.
"""
def __init__(self) -> None:
self._traces: list[QuicLoggerTrace] = []
def start_trace(self, is_client: bool, odcid: bytes) -> QuicLoggerTrace:
trace = QuicLoggerTrace(is_client=is_client, odcid=odcid)
self._traces.append(trace)
return trace
def end_trace(self, trace: QuicLoggerTrace) -> None:
assert trace in self._traces, "QuicLoggerTrace does not belong to QuicLogger"
def to_dict(self) -> dict[str, Any]:
"""
Return the traces as a dictionary which can be written as JSON.
"""
return {
"qlog_format": "JSON",
"qlog_version": QLOG_VERSION,
"traces": [trace.to_dict() for trace in self._traces],
}
class QuicFileLogger(QuicLogger):
"""
A QUIC event logger which writes one trace per file.
"""
def __init__(self, path: str) -> None:
if not os.path.isdir(path):
raise ValueError(f"QUIC log output directory '{path}' does not exist")
self.path = path
super().__init__()
def end_trace(self, trace: QuicLoggerTrace) -> None:
trace_dict = trace.to_dict()
trace_path = os.path.join(
self.path, trace_dict["common_fields"]["ODCID"] + ".qlog"
)
with open(trace_path, "w") as logger_fp:
json.dump(
{
"qlog_format": "JSON",
"qlog_version": QLOG_VERSION,
"traces": [trace_dict],
},
logger_fp,
)
self._traces.remove(trace)

View File

@@ -0,0 +1,628 @@
from __future__ import annotations
import binascii
import ipaddress
import os
from dataclasses import dataclass
from enum import IntEnum
from .._compat import DATACLASS_KWARGS
from .._hazmat import AeadAes128Gcm, Buffer, RangeSet
PACKET_LONG_HEADER = 0x80
PACKET_FIXED_BIT = 0x40
PACKET_SPIN_BIT = 0x20
CONNECTION_ID_MAX_SIZE = 20
PACKET_NUMBER_MAX_SIZE = 4
RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e")
RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92")
RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb")
RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a")
RETRY_INTEGRITY_TAG_SIZE = 16
STATELESS_RESET_TOKEN_SIZE = 16
class QuicErrorCode(IntEnum):
NO_ERROR = 0x0
INTERNAL_ERROR = 0x1
CONNECTION_REFUSED = 0x2
FLOW_CONTROL_ERROR = 0x3
STREAM_LIMIT_ERROR = 0x4
STREAM_STATE_ERROR = 0x5
FINAL_SIZE_ERROR = 0x6
FRAME_ENCODING_ERROR = 0x7
TRANSPORT_PARAMETER_ERROR = 0x8
CONNECTION_ID_LIMIT_ERROR = 0x9
PROTOCOL_VIOLATION = 0xA
INVALID_TOKEN = 0xB
APPLICATION_ERROR = 0xC
CRYPTO_BUFFER_EXCEEDED = 0xD
KEY_UPDATE_ERROR = 0xE
AEAD_LIMIT_REACHED = 0xF
VERSION_NEGOTIATION_ERROR = 0x11
CRYPTO_ERROR = 0x100
class QuicPacketType(IntEnum):
INITIAL = 0
ZERO_RTT = 1
HANDSHAKE = 2
RETRY = 3
VERSION_NEGOTIATION = 4
ONE_RTT = 5
# For backwards compatibility only, use `QuicPacketType` in new code.
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL
# QUIC version 1
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
QuicPacketType.INITIAL: 0,
QuicPacketType.ZERO_RTT: 1,
QuicPacketType.HANDSHAKE: 2,
QuicPacketType.RETRY: 3,
}
PACKET_LONG_TYPE_DECODE_VERSION_1 = {
v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
}
# QUIC version 2
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
QuicPacketType.INITIAL: 1,
QuicPacketType.ZERO_RTT: 2,
QuicPacketType.HANDSHAKE: 3,
QuicPacketType.RETRY: 0,
}
PACKET_LONG_TYPE_DECODE_VERSION_2 = {
v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
}
class QuicProtocolVersion(IntEnum):
NEGOTIATION = 0
VERSION_1 = 0x00000001
VERSION_2 = 0x6B3343CF
@dataclass(**DATACLASS_KWARGS)
class QuicHeader:
version: int | None
"The protocol version. Only present in long header packets."
packet_type: QuicPacketType
"The type of the packet."
packet_length: int
"The total length of the packet, in bytes."
destination_cid: bytes
"The destination connection ID."
source_cid: bytes
"The destination connection ID."
token: bytes
"The address verification token. Only present in `INITIAL` and `RETRY` packets."
integrity_tag: bytes
"The retry integrity tag. Only present in `RETRY` packets."
supported_versions: list[int]
"Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."
def get_retry_integrity_tag(
packet_without_tag: bytes, original_destination_cid: bytes, version: int
) -> bytes:
"""
Calculate the integrity tag for a RETRY packet.
"""
# build Retry pseudo packet
buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
buf.push_uint8(len(original_destination_cid))
buf.push_bytes(original_destination_cid)
buf.push_bytes(packet_without_tag)
assert buf.eof()
if version == QuicProtocolVersion.VERSION_2:
aead_key = RETRY_AEAD_KEY_VERSION_2
aead_nonce = RETRY_AEAD_NONCE_VERSION_2
else:
aead_key = RETRY_AEAD_KEY_VERSION_1
aead_nonce = RETRY_AEAD_NONCE_VERSION_1
# run AES-128-GCM
aead = AeadAes128Gcm(aead_key, b"null!12bytes")
integrity_tag = aead.encrypt_with_nonce(aead_nonce, b"", buf.data)
assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
return integrity_tag
def get_spin_bit(first_byte: int) -> bool:
if first_byte & PACKET_SPIN_BIT:
return True
return False
def is_long_header(first_byte: int) -> bool:
if first_byte & PACKET_LONG_HEADER:
return True
return False
def pretty_protocol_version(version: int) -> str:
"""
Return a user-friendly representation of a protocol version.
"""
try:
version_name = QuicProtocolVersion(version).name
except ValueError:
version_name = "UNKNOWN"
return f"0x{version:08x} ({version_name})"
def pull_quic_header(buf: Buffer, host_cid_length: int | None = None) -> QuicHeader:
packet_start = buf.tell()
version: int | None
integrity_tag = b""
supported_versions = []
token = b""
first_byte = buf.pull_uint8()
if is_long_header(first_byte):
# Long Header Packets.
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
version = buf.pull_uint32()
destination_cid_length = buf.pull_uint8()
if destination_cid_length > CONNECTION_ID_MAX_SIZE:
raise ValueError(
f"Destination CID is too long ({destination_cid_length} bytes)"
)
destination_cid = buf.pull_bytes(destination_cid_length)
source_cid_length = buf.pull_uint8()
if source_cid_length > CONNECTION_ID_MAX_SIZE:
raise ValueError(f"Source CID is too long ({source_cid_length} bytes)")
source_cid = buf.pull_bytes(source_cid_length)
if version == QuicProtocolVersion.NEGOTIATION:
# Version Negotiation Packet.
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
packet_type = QuicPacketType.VERSION_NEGOTIATION
while not buf.eof():
supported_versions.append(buf.pull_uint32())
packet_end = buf.tell()
else:
if not (first_byte & PACKET_FIXED_BIT):
raise ValueError("Packet fixed bit is zero")
if version == QuicProtocolVersion.VERSION_2:
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[
(first_byte & 0x30) >> 4
]
else:
packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[
(first_byte & 0x30) >> 4
]
if packet_type == QuicPacketType.INITIAL:
token_length = buf.pull_uint_var()
token = buf.pull_bytes(token_length)
rest_length = buf.pull_uint_var()
elif packet_type == QuicPacketType.ZERO_RTT:
rest_length = buf.pull_uint_var()
elif packet_type == QuicPacketType.HANDSHAKE:
rest_length = buf.pull_uint_var()
else:
token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
token = buf.pull_bytes(token_length)
integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
rest_length = 0
# Check remainder length.
packet_end = buf.tell() + rest_length
if packet_end > buf.capacity:
raise ValueError("Packet payload is truncated")
else:
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
if not (first_byte & PACKET_FIXED_BIT):
raise ValueError("Packet fixed bit is zero")
version = None
packet_type = QuicPacketType.ONE_RTT
destination_cid = buf.pull_bytes(host_cid_length)
source_cid = b""
packet_end = buf.capacity
return QuicHeader(
version=version,
packet_type=packet_type,
packet_length=packet_end - packet_start,
destination_cid=destination_cid,
source_cid=source_cid,
token=token,
integrity_tag=integrity_tag,
supported_versions=supported_versions,
)
def encode_long_header_first_byte(
version: int, packet_type: QuicPacketType, bits: int
) -> int:
"""
Encode the first byte of a long header packet.
"""
if version == QuicProtocolVersion.VERSION_2:
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_2
else:
long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_1
return (
PACKET_LONG_HEADER
| PACKET_FIXED_BIT
| long_type_encode[packet_type] << 4
| bits
)
def encode_quic_retry(
version: int,
source_cid: bytes,
destination_cid: bytes,
original_destination_cid: bytes,
retry_token: bytes,
unused: int = 0,
) -> bytes:
buf = Buffer(
capacity=7
+ len(destination_cid)
+ len(source_cid)
+ len(retry_token)
+ RETRY_INTEGRITY_TAG_SIZE
)
buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused))
buf.push_uint32(version)
buf.push_uint8(len(destination_cid))
buf.push_bytes(destination_cid)
buf.push_uint8(len(source_cid))
buf.push_bytes(source_cid)
buf.push_bytes(retry_token)
buf.push_bytes(
get_retry_integrity_tag(buf.data, original_destination_cid, version=version)
)
assert buf.eof()
return buf.data
def encode_quic_version_negotiation(
source_cid: bytes, destination_cid: bytes, supported_versions: list[int]
) -> bytes:
buf = Buffer(
capacity=7
+ len(destination_cid)
+ len(source_cid)
+ 4 * len(supported_versions)
)
buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
buf.push_uint8(len(destination_cid))
buf.push_bytes(destination_cid)
buf.push_uint8(len(source_cid))
buf.push_bytes(source_cid)
for version in supported_versions:
buf.push_uint32(version)
return buf.data
# TLS EXTENSION
@dataclass(**DATACLASS_KWARGS)
class QuicPreferredAddress:
ipv4_address: tuple[str, int] | None
ipv6_address: tuple[str, int] | None
connection_id: bytes
stateless_reset_token: bytes
@dataclass(**DATACLASS_KWARGS)
class QuicVersionInformation:
chosen_version: int
available_versions: list[int]
@dataclass()
class QuicTransportParameters:
original_destination_connection_id: bytes | None = None
max_idle_timeout: int | None = None
stateless_reset_token: bytes | None = None
max_udp_payload_size: int | None = None
initial_max_data: int | None = None
initial_max_stream_data_bidi_local: int | None = None
initial_max_stream_data_bidi_remote: int | None = None
initial_max_stream_data_uni: int | None = None
initial_max_streams_bidi: int | None = None
initial_max_streams_uni: int | None = None
ack_delay_exponent: int | None = None
max_ack_delay: int | None = None
disable_active_migration: bool | None = False
preferred_address: QuicPreferredAddress | None = None
active_connection_id_limit: int | None = None
initial_source_connection_id: bytes | None = None
retry_source_connection_id: bytes | None = None
version_information: QuicVersionInformation | None = None
max_datagram_frame_size: int | None = None
quantum_readiness: bytes | None = None
PARAMS = {
0x00: ("original_destination_connection_id", bytes),
0x01: ("max_idle_timeout", int),
0x02: ("stateless_reset_token", bytes),
0x03: ("max_udp_payload_size", int),
0x04: ("initial_max_data", int),
0x05: ("initial_max_stream_data_bidi_local", int),
0x06: ("initial_max_stream_data_bidi_remote", int),
0x07: ("initial_max_stream_data_uni", int),
0x08: ("initial_max_streams_bidi", int),
0x09: ("initial_max_streams_uni", int),
0x0A: ("ack_delay_exponent", int),
0x0B: ("max_ack_delay", int),
0x0C: ("disable_active_migration", bool),
0x0D: ("preferred_address", QuicPreferredAddress),
0x0E: ("active_connection_id_limit", int),
0x0F: ("initial_source_connection_id", bytes),
0x10: ("retry_source_connection_id", bytes),
# https://datatracker.ietf.org/doc/html/rfc9368#section-3
0x11: ("version_information", QuicVersionInformation),
# extensions
0x0020: ("max_datagram_frame_size", int),
0x0C37: ("quantum_readiness", bytes),
}
def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
ipv4_address = None
ipv4_host = buf.pull_bytes(4)
ipv4_port = buf.pull_uint16()
if ipv4_host != bytes(4):
ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)
ipv6_address = None
ipv6_host = buf.pull_bytes(16)
ipv6_port = buf.pull_uint16()
if ipv6_host != bytes(16):
ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)
connection_id_length = buf.pull_uint8()
connection_id = buf.pull_bytes(connection_id_length)
stateless_reset_token = buf.pull_bytes(16)
return QuicPreferredAddress(
ipv4_address=ipv4_address,
ipv6_address=ipv6_address,
connection_id=connection_id,
stateless_reset_token=stateless_reset_token,
)
def push_quic_preferred_address(
buf: Buffer, preferred_address: QuicPreferredAddress
) -> None:
if preferred_address.ipv4_address is not None:
buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
buf.push_uint16(preferred_address.ipv4_address[1])
else:
buf.push_bytes(bytes(6))
if preferred_address.ipv6_address is not None:
buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
buf.push_uint16(preferred_address.ipv6_address[1])
else:
buf.push_bytes(bytes(18))
buf.push_uint8(len(preferred_address.connection_id))
buf.push_bytes(preferred_address.connection_id)
buf.push_bytes(preferred_address.stateless_reset_token)
def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
chosen_version = buf.pull_uint32()
available_versions = []
for i in range(length // 4 - 1):
available_versions.append(buf.pull_uint32())
# If an endpoint receives a Chosen Version equal to zero, or any Available Version
# equal to zero, it MUST treat it as a parsing failure.
#
# https://datatracker.ietf.org/doc/html/rfc9368#section-4
if chosen_version == 0 or 0 in available_versions:
raise ValueError("Version Information must not contain version 0")
return QuicVersionInformation(
chosen_version=chosen_version,
available_versions=available_versions,
)
def push_quic_version_information(
buf: Buffer, version_information: QuicVersionInformation
) -> None:
buf.push_uint32(version_information.chosen_version)
for version in version_information.available_versions:
buf.push_uint32(version)
def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
params = QuicTransportParameters()
while not buf.eof():
param_id = buf.pull_uint_var()
param_len = buf.pull_uint_var()
param_start = buf.tell()
if param_id in PARAMS:
# parse known parameter
param_name, param_type = PARAMS[param_id]
if param_type is int:
setattr(params, param_name, buf.pull_uint_var())
elif param_type is bytes:
setattr(params, param_name, buf.pull_bytes(param_len))
elif param_type is QuicPreferredAddress:
setattr(params, param_name, pull_quic_preferred_address(buf))
elif param_type is QuicVersionInformation:
setattr(
params,
param_name,
pull_quic_version_information(buf, param_len),
)
else:
setattr(params, param_name, True)
else:
# skip unknown parameter
buf.pull_bytes(param_len)
if buf.tell() != param_start + param_len:
raise ValueError("Transport parameter length does not match")
return params
def push_quic_transport_parameters(
buf: Buffer, params: QuicTransportParameters
) -> None:
for param_id, (param_name, param_type) in PARAMS.items():
param_value = getattr(params, param_name)
if param_value is not None and param_value is not False:
param_buf = Buffer(capacity=65536)
if param_type is int:
param_buf.push_uint_var(param_value)
elif param_type is bytes:
param_buf.push_bytes(param_value)
elif param_type is QuicPreferredAddress:
push_quic_preferred_address(param_buf, param_value)
elif param_type is QuicVersionInformation:
push_quic_version_information(param_buf, param_value)
buf.push_uint_var(param_id)
buf.push_uint_var(param_buf.tell())
buf.push_bytes(param_buf.data)
# FRAMES
class QuicFrameType(IntEnum):
PADDING = 0x00
PING = 0x01
ACK = 0x02
ACK_ECN = 0x03
RESET_STREAM = 0x04
STOP_SENDING = 0x05
CRYPTO = 0x06
NEW_TOKEN = 0x07
STREAM_BASE = 0x08
MAX_DATA = 0x10
MAX_STREAM_DATA = 0x11
MAX_STREAMS_BIDI = 0x12
MAX_STREAMS_UNI = 0x13
DATA_BLOCKED = 0x14
STREAM_DATA_BLOCKED = 0x15
STREAMS_BLOCKED_BIDI = 0x16
STREAMS_BLOCKED_UNI = 0x17
NEW_CONNECTION_ID = 0x18
RETIRE_CONNECTION_ID = 0x19
PATH_CHALLENGE = 0x1A
PATH_RESPONSE = 0x1B
TRANSPORT_CLOSE = 0x1C
APPLICATION_CLOSE = 0x1D
HANDSHAKE_DONE = 0x1E
DATAGRAM = 0x30
DATAGRAM_WITH_LENGTH = 0x31
NON_ACK_ELICITING_FRAME_TYPES = frozenset(
[
QuicFrameType.ACK,
QuicFrameType.ACK_ECN,
QuicFrameType.PADDING,
QuicFrameType.TRANSPORT_CLOSE,
QuicFrameType.APPLICATION_CLOSE,
]
)
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
[
QuicFrameType.ACK,
QuicFrameType.ACK_ECN,
QuicFrameType.TRANSPORT_CLOSE,
QuicFrameType.APPLICATION_CLOSE,
]
)
PROBING_FRAME_TYPES = frozenset(
[
QuicFrameType.PATH_CHALLENGE,
QuicFrameType.PATH_RESPONSE,
QuicFrameType.PADDING,
QuicFrameType.NEW_CONNECTION_ID,
]
)
@dataclass(**DATACLASS_KWARGS)
class QuicResetStreamFrame:
error_code: int
final_size: int
stream_id: int
@dataclass(**DATACLASS_KWARGS)
class QuicStopSendingFrame:
error_code: int
stream_id: int
@dataclass(**DATACLASS_KWARGS)
class QuicStreamFrame:
data: bytes = b""
fin: bool = False
offset: int = 0
def pull_ack_frame(buf: Buffer) -> tuple[RangeSet, int]:
rangeset = RangeSet()
end = buf.pull_uint_var() # largest acknowledged
delay = buf.pull_uint_var()
ack_range_count = buf.pull_uint_var()
ack_count = buf.pull_uint_var() # first ack range
rangeset.add(end - ack_count, end + 1)
end -= ack_count
for _ in range(ack_range_count):
end -= buf.pull_uint_var() + 2
ack_count = buf.pull_uint_var()
rangeset.add(end - ack_count, end + 1)
end -= ack_count
return rangeset, delay
def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
ranges = len(rangeset)
index = ranges - 1
r = rangeset[index]
buf.push_uint_var(r[1] - 1)
buf.push_uint_var(delay)
buf.push_uint_var(index)
buf.push_uint_var(r[1] - 1 - r[0])
start = r[0]
while index > 0:
index -= 1
r = rangeset[index]
buf.push_uint_var(start - r[1] - 1)
buf.push_uint_var(r[1] - r[0] - 1)
start = r[0]
return ranges

View File

@@ -0,0 +1,437 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Callable, Sequence
from .._compat import DATACLASS_KWARGS
from .._hazmat import Buffer, size_uint_var
from ..tls import Epoch
from .crypto import CryptoPair
from .logger import QuicLoggerTrace
from .packet import (
NON_ACK_ELICITING_FRAME_TYPES,
NON_IN_FLIGHT_FRAME_TYPES,
PACKET_FIXED_BIT,
PACKET_NUMBER_MAX_SIZE,
QuicFrameType,
QuicPacketType,
encode_long_header_first_byte,
)
# MinPacketSize and MaxPacketSize control the packet sizes for UDP datagrams.
# If MinPacketSize is unset, a default value of 1280 bytes
# will be used during the handshake.
# If MaxPacketSize is unset, a default value of 1452 bytes will be used.
# DPLPMTUD will automatically determine the MTU supported
# by the link-up to the MaxPacketSize,
# except for in the case where MinPacketSize and MaxPacketSize
# are configured to the same value,
# in which case path MTU discovery will be disabled.
# Values above 65355 are invalid.
# 20-bytes for IPv6 overhead.
# 1280 is very conservative
# Chrome tries 1350 at startup
# we should do a rudimentary MTU discovery
# Sending a PING frame 1350
# THEN 1452
SMALLEST_MAX_DATAGRAM_SIZE = 1200
PACKET_MAX_SIZE = 1280
MTU_PROBE_SIZES = [1350, 1452]
PACKET_LENGTH_SEND_SIZE = 2
PACKET_NUMBER_SEND_SIZE = 2
QuicDeliveryHandler = Callable[..., None]
class QuicDeliveryState(IntEnum):
ACKED = 0
LOST = 1
@dataclass(**DATACLASS_KWARGS)
class QuicSentPacket:
epoch: Epoch
in_flight: bool
is_ack_eliciting: bool
is_crypto_packet: bool
packet_number: int
packet_type: QuicPacketType
sent_time: float | None = None
sent_bytes: int = 0
delivery_handlers: list[tuple[QuicDeliveryHandler, Any]] = field(
default_factory=list
)
quic_logger_frames: list[dict] = field(default_factory=list)
class QuicPacketBuilderStop(Exception):
pass
class QuicPacketBuilder:
"""
Helper for building QUIC packets.
"""
__slots__ = (
"max_flight_bytes",
"max_total_bytes",
"quic_logger_frames",
"_host_cid",
"_is_client",
"_peer_cid",
"_peer_token",
"_quic_logger",
"_spin_bit",
"_version",
"_datagrams",
"_datagram_flight_bytes",
"_datagram_init",
"_datagram_needs_padding",
"_packets",
"_flight_bytes",
"_total_bytes",
"_header_size",
"_packet",
"_packet_crypto",
"_packet_long_header",
"_packet_number",
"_packet_start",
"_packet_type",
"_buffer",
"_buffer_capacity",
"_flight_capacity",
)
def __init__(
self,
*,
host_cid: bytes,
peer_cid: bytes,
version: int,
is_client: bool,
max_datagram_size: int = PACKET_MAX_SIZE,
packet_number: int = 0,
peer_token: bytes = b"",
quic_logger: QuicLoggerTrace | None = None,
spin_bit: bool = False,
):
self.max_flight_bytes: int | None = None
self.max_total_bytes: int | None = None
self.quic_logger_frames: list[dict] | None = None
self._host_cid = host_cid
self._is_client = is_client
self._peer_cid = peer_cid
self._peer_token = peer_token
self._quic_logger = quic_logger
self._spin_bit = spin_bit
self._version = version
# assembled datagrams and packets
self._datagrams: list[bytes] = []
self._datagram_flight_bytes = 0
self._datagram_init = True
self._datagram_needs_padding = False
self._packets: list[QuicSentPacket] = []
self._flight_bytes = 0
self._total_bytes = 0
# current packet
self._header_size = 0
self._packet: QuicSentPacket | None = None
self._packet_crypto: CryptoPair | None = None
self._packet_long_header = False
self._packet_number = packet_number
self._packet_start = 0
self._packet_type: QuicPacketType | None = None
self._buffer = Buffer(max_datagram_size)
self._buffer_capacity = max_datagram_size
self._flight_capacity = max_datagram_size
@property
def packet_is_empty(self) -> bool:
"""
Returns `True` if the current packet is empty.
"""
assert self._packet is not None
packet_size = self._buffer.tell() - self._packet_start
return packet_size <= self._header_size
@property
def packet_number(self) -> int:
"""
Returns the packet number for the next packet.
"""
return self._packet_number
@property
def remaining_buffer_space(self) -> int:
"""
Returns the remaining number of bytes which can be used in
the current packet.
"""
return (
self._buffer_capacity
- self._buffer.tell()
- self._packet_crypto.aead_tag_size
)
@property
def remaining_flight_space(self) -> int:
"""
Returns the remaining number of bytes which can be used in
the current packet.
"""
return (
self._flight_capacity
- self._buffer.tell()
- self._packet_crypto.aead_tag_size
)
def flush(self) -> tuple[list[bytes], list[QuicSentPacket]]:
"""
Returns the assembled datagrams.
"""
if self._packet is not None:
self._end_packet()
self._flush_current_datagram()
datagrams = self._datagrams
packets = self._packets
self._datagrams = []
self._packets = []
return datagrams, packets
def start_frame(
self,
frame_type: int,
capacity: int = 1,
handler: QuicDeliveryHandler | None = None,
handler_args: Sequence[Any] = [],
) -> Buffer:
"""
Starts a new frame.
"""
if self.remaining_buffer_space < capacity or (
frame_type not in NON_IN_FLIGHT_FRAME_TYPES
and self.remaining_flight_space < capacity
):
raise QuicPacketBuilderStop
self._buffer.push_uint_var(frame_type)
if frame_type not in NON_ACK_ELICITING_FRAME_TYPES:
self._packet.is_ack_eliciting = True
if frame_type not in NON_IN_FLIGHT_FRAME_TYPES:
self._packet.in_flight = True
if frame_type == QuicFrameType.CRYPTO:
self._packet.is_crypto_packet = True
if handler is not None:
self._packet.delivery_handlers.append((handler, handler_args))
return self._buffer
def start_packet(self, packet_type: QuicPacketType, crypto: CryptoPair) -> None:
"""
Starts a new packet.
"""
assert packet_type not in {
QuicPacketType.RETRY,
QuicPacketType.VERSION_NEGOTIATION,
}, "Invalid packet type"
buf = self._buffer
# finish previous datagram
if self._packet is not None:
self._end_packet()
# if there is too little space remaining, start a new datagram
# FIXME: the limit is arbitrary!
packet_start = buf.tell()
if self._buffer_capacity - packet_start < 128:
self._flush_current_datagram()
packet_start = 0
# initialize datagram if needed
if self._datagram_init:
if self.max_total_bytes is not None:
remaining_total_bytes = self.max_total_bytes - self._total_bytes
if remaining_total_bytes < self._buffer_capacity:
self._buffer_capacity = remaining_total_bytes
self._flight_capacity = self._buffer_capacity
if self.max_flight_bytes is not None:
remaining_flight_bytes = self.max_flight_bytes - self._flight_bytes
if remaining_flight_bytes < self._flight_capacity:
self._flight_capacity = remaining_flight_bytes
self._datagram_flight_bytes = 0
self._datagram_init = False
self._datagram_needs_padding = False
# calculate header size
if packet_type != QuicPacketType.ONE_RTT:
header_size = 11 + len(self._peer_cid) + len(self._host_cid)
if packet_type == QuicPacketType.INITIAL:
token_length = len(self._peer_token)
header_size += size_uint_var(token_length) + token_length
else:
header_size = 3 + len(self._peer_cid)
# check we have enough space
if packet_start + header_size >= self._buffer_capacity:
raise QuicPacketBuilderStop
# determine ack epoch
if packet_type == QuicPacketType.INITIAL:
epoch = Epoch.INITIAL
elif packet_type == QuicPacketType.HANDSHAKE:
epoch = Epoch.HANDSHAKE
else:
epoch = Epoch.ONE_RTT
self._header_size = header_size
self._packet = QuicSentPacket(
epoch=epoch,
in_flight=False,
is_ack_eliciting=False,
is_crypto_packet=False,
packet_number=self._packet_number,
packet_type=packet_type,
)
self._packet_crypto = crypto
self._packet_start = packet_start
self._packet_type = packet_type
self.quic_logger_frames = self._packet.quic_logger_frames
buf.seek(self._packet_start + self._header_size)
def _end_packet(self) -> None:
"""
Ends the current packet.
"""
buf = self._buffer
packet_size = buf.tell() - self._packet_start
if packet_size > self._header_size:
# padding to ensure sufficient sample size
padding_size = (
PACKET_NUMBER_MAX_SIZE
- PACKET_NUMBER_SEND_SIZE
+ self._header_size
- packet_size
)
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if (
self._is_client or self._packet.is_ack_eliciting
) and self._packet_type == QuicPacketType.INITIAL:
self._datagram_needs_padding = True
# For datagrams containing 1-RTT data, we *must* apply the padding
# inside the packet, we cannot tack bytes onto the end of the
# datagram.
if (
self._datagram_needs_padding
and self._packet_type == QuicPacketType.ONE_RTT
):
if self.remaining_flight_space > padding_size:
padding_size = self.remaining_flight_space
self._datagram_needs_padding = False
# write padding
if padding_size > 0:
buf.push_bytes(bytes(padding_size))
packet_size += padding_size
self._packet.in_flight = True
# log frame
if self._quic_logger is not None:
self._packet.quic_logger_frames.append(
self._quic_logger.encode_padding_frame()
)
# write header
if self._packet_type != QuicPacketType.ONE_RTT:
length = (
packet_size
- self._header_size
+ PACKET_NUMBER_SEND_SIZE
+ self._packet_crypto.aead_tag_size
)
buf.seek(self._packet_start)
buf.push_uint8(
encode_long_header_first_byte(
self._version, self._packet_type, PACKET_NUMBER_SEND_SIZE - 1
)
)
buf.push_uint32(self._version)
buf.push_uint8(len(self._peer_cid))
buf.push_bytes(self._peer_cid)
buf.push_uint8(len(self._host_cid))
buf.push_bytes(self._host_cid)
if self._packet_type == QuicPacketType.INITIAL:
buf.push_uint_var(len(self._peer_token))
buf.push_bytes(self._peer_token)
buf.push_uint16(length | 0x4000)
buf.push_uint16(self._packet_number & 0xFFFF)
else:
buf.seek(self._packet_start)
buf.push_uint8(
PACKET_FIXED_BIT
| (self._spin_bit << 5)
| (self._packet_crypto.key_phase << 2)
| (PACKET_NUMBER_SEND_SIZE - 1)
)
buf.push_bytes(self._peer_cid)
buf.push_uint16(self._packet_number & 0xFFFF)
# encrypt in place
plain = buf.data_slice(self._packet_start, self._packet_start + packet_size)
buf.seek(self._packet_start)
buf.push_bytes(
self._packet_crypto.encrypt_packet(
plain[0 : self._header_size],
plain[self._header_size : packet_size],
self._packet_number,
)
)
self._packet.sent_bytes = buf.tell() - self._packet_start
self._packets.append(self._packet)
if self._packet.in_flight:
self._datagram_flight_bytes += self._packet.sent_bytes
# Short header packets cannot be coalesced, we need a new datagram.
if self._packet_type == QuicPacketType.ONE_RTT:
self._flush_current_datagram()
self._packet_number += 1
else:
# "cancel" the packet
buf.seek(self._packet_start)
self._packet = None
self.quic_logger_frames = None
def _flush_current_datagram(self) -> None:
datagram_bytes = self._buffer.tell()
if datagram_bytes:
# Padding for datagrams containing initial packets; see RFC 9000
# section 14.1.
if self._datagram_needs_padding:
extra_bytes = self._flight_capacity - self._buffer.tell()
if extra_bytes > 0:
self._buffer.push_bytes(bytes(extra_bytes))
self._datagram_flight_bytes += extra_bytes
datagram_bytes += extra_bytes
self._datagrams.append(self._buffer.data)
self._flight_bytes += self._datagram_flight_bytes
self._total_bytes += datagram_bytes
self._datagram_init = True
self._buffer.seek(0)

View File

@@ -0,0 +1,503 @@
from __future__ import annotations
import logging
import math
from typing import Any, Callable, Iterable
from .._hazmat import QuicPacketPacer, QuicRttMonitor, RangeSet
from .logger import QuicLoggerTrace
from .packet_builder import QuicDeliveryState, QuicSentPacket
# loss detection
K_PACKET_THRESHOLD = 3
K_GRANULARITY = 0.001 # seconds
K_TIME_THRESHOLD = 9 / 8
K_MICRO_SECOND = 0.000001
K_SECOND = 1.0
# congestion control
K_INITIAL_WINDOW = 10
K_MINIMUM_WINDOW = 2
# Cubic constants (RFC 9438)
K_CUBIC_C = 0.4
K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7
K_CUBIC_MAX_IDLE_TIME = 2.0 # seconds
def _cubic_root(x: float) -> float:
if x < 0:
return -((-x) ** (1.0 / 3.0))
return x ** (1.0 / 3.0)
class QuicPacketSpace:
def __init__(self) -> None:
self.ack_at: float | None = None
self.ack_queue = RangeSet()
self.discarded = False
self.expected_packet_number = 0
self.largest_received_packet = -1
self.largest_received_time: float | None = None
# sent packets and loss
self.ack_eliciting_in_flight = 0
self.largest_acked_packet = 0
self.loss_time: float | None = None
self.sent_packets: dict[int, QuicSentPacket] = {}
class QuicCongestionControl:
"""
Cubic congestion control (RFC 9438).
"""
def __init__(self, max_datagram_size: int) -> None:
self._max_datagram_size = max_datagram_size
self._rtt_monitor = QuicRttMonitor()
self._congestion_recovery_start_time = 0.0
self._rtt = 0.02 # initial RTT estimate (20 ms)
self._last_ack = 0.0
self.bytes_in_flight = 0
self.congestion_window = max_datagram_size * K_INITIAL_WINDOW
self.ssthresh: int | None = None
# Cubic state
self._first_slow_start = True
self._starting_congestion_avoidance = False
self._K: float = 0.0
self._W_max: int = self.congestion_window
self._W_est: int = 0
self._cwnd_epoch: int = 0
self._t_epoch: float = 0.0
def _W_cubic(self, t: float) -> int:
W_max_segments = self._W_max / self._max_datagram_size
target_segments = K_CUBIC_C * (t - self._K) ** 3 + W_max_segments
return int(target_segments * self._max_datagram_size)
def _reset(self) -> None:
self.congestion_window = self._max_datagram_size * K_INITIAL_WINDOW
self.ssthresh = None
self._first_slow_start = True
self._starting_congestion_avoidance = False
self._K = 0.0
self._W_max = self.congestion_window
self._W_est = 0
self._cwnd_epoch = 0
self._t_epoch = 0.0
def _start_epoch(self, now: float) -> None:
self._t_epoch = now
self._cwnd_epoch = self.congestion_window
self._W_est = self._cwnd_epoch
W_max_seg = self._W_max / self._max_datagram_size
cwnd_seg = self._cwnd_epoch / self._max_datagram_size
self._K = _cubic_root((W_max_seg - cwnd_seg) / K_CUBIC_C)
def on_packet_acked(self, packet: QuicSentPacket) -> None:
self.bytes_in_flight -= packet.sent_bytes
self._last_ack = packet.sent_time
if self.ssthresh is None or self.congestion_window < self.ssthresh:
# slow start
self.congestion_window += packet.sent_bytes
else:
# congestion avoidance
if self._first_slow_start and not self._starting_congestion_avoidance:
# exiting slow start without a loss (HyStart triggered)
self._first_slow_start = False
self._W_max = self.congestion_window
self._start_epoch(packet.sent_time)
if self._starting_congestion_avoidance:
# entering congestion avoidance after a loss
self._starting_congestion_avoidance = False
self._first_slow_start = False
self._start_epoch(packet.sent_time)
# TCP-friendly estimate (Reno-like linear growth)
self._W_est = int(
self._W_est
+ self._max_datagram_size * (packet.sent_bytes / self.congestion_window)
)
t = packet.sent_time - self._t_epoch
W_cubic = self._W_cubic(t + self._rtt)
# clamp target
if W_cubic < self.congestion_window:
target = self.congestion_window
elif W_cubic > int(1.5 * self.congestion_window):
target = int(1.5 * self.congestion_window)
else:
target = W_cubic
if self._W_cubic(t) < self._W_est:
# Reno-friendly region
self.congestion_window = self._W_est
else:
# concave / convex region
self.congestion_window = int(
self.congestion_window
+ (target - self.congestion_window)
* (self._max_datagram_size / self.congestion_window)
)
def on_packet_sent(self, packet: QuicSentPacket) -> None:
self.bytes_in_flight += packet.sent_bytes
# reset cwnd after prolonged idle
if self._last_ack > 0.0:
elapsed_idle = packet.sent_time - self._last_ack
if elapsed_idle >= K_CUBIC_MAX_IDLE_TIME:
self._reset()
def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None:
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None:
lost_largest_time = 0.0
for packet in packets:
self.bytes_in_flight -= packet.sent_bytes
lost_largest_time = packet.sent_time
# start a new congestion event if packet was sent after the
# start of the previous congestion recovery period.
if lost_largest_time > self._congestion_recovery_start_time:
self._congestion_recovery_start_time = now
# fast convergence: if W_max is decreasing, reduce it further
if self.congestion_window < self._W_max:
self._W_max = int(
self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2
)
else:
self._W_max = self.congestion_window
self.congestion_window = max(
int(self.congestion_window * K_CUBIC_LOSS_REDUCTION_FACTOR),
self._max_datagram_size * K_MINIMUM_WINDOW,
)
self.ssthresh = self.congestion_window
self._starting_congestion_avoidance = True
def on_rtt_measurement(self, latest_rtt: float, now: float) -> None:
self._rtt = latest_rtt
# check whether we should exit slow start
if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing(
latest_rtt, now
):
self.ssthresh = self.congestion_window
class QuicPacketRecovery:
"""
Packet loss and congestion controller.
"""
def __init__(
self,
initial_rtt: float,
peer_completed_address_validation: bool,
send_probe: Callable[[], None],
max_datagram_size: int = 1280,
logger: logging.LoggerAdapter | None = None,
quic_logger: QuicLoggerTrace | None = None,
) -> None:
self.max_ack_delay = 0.025
self.peer_completed_address_validation = peer_completed_address_validation
self.spaces: list[QuicPacketSpace] = []
# callbacks
self._logger = logger
self._quic_logger = quic_logger
self._send_probe = send_probe
# loss detection
self._pto_count = 0
self._rtt_initial = initial_rtt
self._rtt_initialized = False
self._rtt_latest = 0.0
self._rtt_min = math.inf
self._rtt_smoothed = 0.0
self._rtt_variance = 0.0
self._time_of_last_sent_ack_eliciting_packet = 0.0
# congestion control
self._cc = QuicCongestionControl(max_datagram_size)
self._pacer = QuicPacketPacer(max_datagram_size)
@property
def bytes_in_flight(self) -> int:
return self._cc.bytes_in_flight
@property
def congestion_window(self) -> int:
return self._cc.congestion_window
def discard_space(self, space: QuicPacketSpace) -> None:
assert space in self.spaces
self._cc.on_packets_expired(
filter(lambda x: x.in_flight, space.sent_packets.values())
)
space.sent_packets.clear()
space.ack_at = None
space.ack_eliciting_in_flight = 0
space.loss_time = None
# reset PTO count
self._pto_count = 0
if self._quic_logger is not None:
self._log_metrics_updated()
def get_loss_detection_time(self) -> float:
# loss timer
loss_space = self._get_loss_space()
if loss_space is not None:
return loss_space.loss_time
# packet timer
if (
not self.peer_completed_address_validation
or sum(space.ack_eliciting_in_flight for space in self.spaces) > 0
):
timeout = self.get_probe_timeout() * (2**self._pto_count)
return self._time_of_last_sent_ack_eliciting_packet + timeout
return None
def get_probe_timeout(self) -> float:
if not self._rtt_initialized:
return 2 * self._rtt_initial
return (
self._rtt_smoothed
+ max(4 * self._rtt_variance, K_GRANULARITY)
+ self.max_ack_delay
)
def on_ack_received(
self,
space: QuicPacketSpace,
ack_rangeset: RangeSet,
ack_delay: float,
now: float,
) -> None:
"""
Update metrics as the result of an ACK being received.
"""
is_ack_eliciting = False
largest_acked = ack_rangeset.bounds()[1] - 1
largest_newly_acked = None
largest_sent_time = None
if largest_acked > space.largest_acked_packet:
space.largest_acked_packet = largest_acked
for packet_number in sorted(space.sent_packets.keys()):
if packet_number > largest_acked:
break
if packet_number in ack_rangeset:
# remove packet and update counters
packet = space.sent_packets.pop(packet_number)
if packet.is_ack_eliciting:
is_ack_eliciting = True
space.ack_eliciting_in_flight -= 1
if packet.in_flight:
self._cc.on_packet_acked(packet)
largest_newly_acked = packet_number
largest_sent_time = packet.sent_time
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.ACKED, *args)
# nothing to do if there are no newly acked packets
if largest_newly_acked is None:
return
if largest_acked == largest_newly_acked and is_ack_eliciting:
latest_rtt = now - largest_sent_time
log_rtt = True
# limit ACK delay to max_ack_delay
ack_delay = min(ack_delay, self.max_ack_delay)
# update RTT estimate, which cannot be < 1 ms
self._rtt_latest = max(latest_rtt, 0.001)
if self._rtt_latest < self._rtt_min:
self._rtt_min = self._rtt_latest
if self._rtt_latest > self._rtt_min + ack_delay:
self._rtt_latest -= ack_delay
if not self._rtt_initialized:
self._rtt_initialized = True
self._rtt_variance = latest_rtt / 2
self._rtt_smoothed = latest_rtt
else:
self._rtt_variance = 3 / 4 * self._rtt_variance + 1 / 4 * abs(
self._rtt_min - self._rtt_latest
)
self._rtt_smoothed = (
7 / 8 * self._rtt_smoothed + 1 / 8 * self._rtt_latest
)
# inform congestion controller
self._cc.on_rtt_measurement(latest_rtt, now=now)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
else:
log_rtt = False
self._detect_loss(space, now=now)
# reset PTO count
self._pto_count = 0
if self._quic_logger is not None:
self._log_metrics_updated(log_rtt=log_rtt)
def on_loss_detection_timeout(self, now: float) -> None:
loss_space = self._get_loss_space()
if loss_space is not None:
self._detect_loss(loss_space, now=now)
else:
self._pto_count += 1
self.reschedule_data(now=now)
def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None:
space.sent_packets[packet.packet_number] = packet
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight += 1
if packet.in_flight:
if packet.is_ack_eliciting:
self._time_of_last_sent_ack_eliciting_packet = packet.sent_time
# add packet to bytes in flight
self._cc.on_packet_sent(packet)
if self._quic_logger is not None:
self._log_metrics_updated()
def reschedule_data(self, now: float) -> None:
"""
Schedule some data for retransmission.
"""
# if there is any outstanding CRYPTO, retransmit it
crypto_scheduled = False
for space in self.spaces:
packets = tuple(
filter(lambda i: i.is_crypto_packet, space.sent_packets.values())
)
if packets:
self._on_packets_lost(packets, space=space, now=now)
crypto_scheduled = True
if crypto_scheduled and self._logger is not None:
self._logger.debug("Scheduled CRYPTO data for retransmission")
# ensure an ACK-elliciting packet is sent
self._send_probe()
def _detect_loss(self, space: QuicPacketSpace, now: float) -> None:
"""
Check whether any packets should be declared lost.
"""
loss_delay = K_TIME_THRESHOLD * (
max(self._rtt_latest, self._rtt_smoothed)
if self._rtt_initialized
else self._rtt_initial
)
packet_threshold = space.largest_acked_packet - K_PACKET_THRESHOLD
time_threshold = now - loss_delay
lost_packets = []
space.loss_time = None
for packet_number, packet in space.sent_packets.items():
if packet_number > space.largest_acked_packet:
break
if packet_number <= packet_threshold or packet.sent_time <= time_threshold:
lost_packets.append(packet)
else:
packet_loss_time = packet.sent_time + loss_delay
if space.loss_time is None or space.loss_time > packet_loss_time:
space.loss_time = packet_loss_time
self._on_packets_lost(lost_packets, space=space, now=now)
def _get_loss_space(self) -> QuicPacketSpace | None:
loss_space = None
for space in self.spaces:
if space.loss_time is not None and (
loss_space is None or space.loss_time < loss_space.loss_time
):
loss_space = space
return loss_space
def _log_metrics_updated(self, log_rtt=False) -> None:
data: dict[str, Any] = {
"bytes_in_flight": self._cc.bytes_in_flight,
"cwnd": self._cc.congestion_window,
}
if self._cc.ssthresh is not None:
data["ssthresh"] = self._cc.ssthresh
if log_rtt:
data.update(
{
"latest_rtt": self._quic_logger.encode_time(self._rtt_latest),
"min_rtt": self._quic_logger.encode_time(self._rtt_min),
"smoothed_rtt": self._quic_logger.encode_time(self._rtt_smoothed),
"rtt_variance": self._quic_logger.encode_time(self._rtt_variance),
}
)
self._quic_logger.log_event(
category="recovery", event="metrics_updated", data=data
)
def _on_packets_lost(
self, packets: Iterable[QuicSentPacket], space: QuicPacketSpace, now: float
) -> None:
lost_packets_cc = []
for packet in packets:
del space.sent_packets[packet.packet_number]
if packet.in_flight:
lost_packets_cc.append(packet)
if packet.is_ack_eliciting:
space.ack_eliciting_in_flight -= 1
if self._quic_logger is not None:
self._quic_logger.log_event(
category="recovery",
event="packet_lost",
data={
"type": self._quic_logger.packet_type(packet.packet_type),
"packet_number": packet.packet_number,
},
)
self._log_metrics_updated()
# trigger callbacks
for handler, args in packet.delivery_handlers:
handler(QuicDeliveryState.LOST, *args)
# inform congestion controller
if lost_packets_cc:
self._cc.on_packets_lost(lost_packets_cc, now=now)
self._pacer.update_rate(
congestion_window=self._cc.congestion_window,
smoothed_rtt=self._rtt_smoothed,
)
if self._quic_logger is not None:
self._log_metrics_updated()

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
import ipaddress
from .._hazmat import Buffer, Rsa
from ..tls import pull_opaque, push_opaque
from .connection import NetworkAddress
def encode_address(addr: NetworkAddress) -> bytes:
return ipaddress.ip_address(addr[0]).packed + bytes([addr[1] >> 8, addr[1] & 0xFF])
class QuicRetryTokenHandler:
def __init__(self) -> None:
self._key = Rsa(key_size=2048)
def create_token(
self,
addr: NetworkAddress,
original_destination_connection_id: bytes,
retry_source_connection_id: bytes,
) -> bytes:
buf = Buffer(capacity=512)
push_opaque(buf, 1, encode_address(addr))
push_opaque(buf, 1, original_destination_connection_id)
push_opaque(buf, 1, retry_source_connection_id)
return self._key.encrypt(buf.data)
def validate_token(self, addr: NetworkAddress, token: bytes) -> tuple[bytes, bytes]:
if not token or len(token) != 256:
raise ValueError("Ciphertext length must be equal to key size.")
buf = Buffer(data=self._key.decrypt(token))
encoded_addr = pull_opaque(buf, 1)
original_destination_connection_id = pull_opaque(buf, 1)
retry_source_connection_id = pull_opaque(buf, 1)
if encoded_addr != encode_address(addr):
raise ValueError("Remote address does not match.")
return original_destination_connection_id, retry_source_connection_id

View File

@@ -0,0 +1,380 @@
from __future__ import annotations
from .._hazmat import RangeSet
from . import events
from .packet import (
QuicErrorCode,
QuicResetStreamFrame,
QuicStopSendingFrame,
QuicStreamFrame,
)
from .packet_builder import QuicDeliveryState
class FinalSizeError(Exception):
pass
class StreamFinishedError(Exception):
pass
class QuicStreamReceiver:
"""
The receive part of a QUIC stream.
It finishes:
- immediately for a send-only stream
- upon reception of a STREAM_RESET frame
- upon reception of a data frame with the FIN bit set
"""
__slots__ = (
"highest_offset",
"is_finished",
"stop_pending",
"_buffer",
"_buffer_start",
"_final_size",
"_ranges",
"_stream_id",
"_stop_error_code",
)
def __init__(self, stream_id: int | None, readable: bool) -> None:
self.highest_offset = 0 # the highest offset ever seen
self.is_finished = False
self.stop_pending = False
self._buffer = bytearray()
self._buffer_start = 0 # the offset for the start of the buffer
self._final_size: int | None = None
self._ranges = RangeSet()
self._stream_id = stream_id
self._stop_error_code: int | None = None
def get_stop_frame(self) -> QuicStopSendingFrame:
self.stop_pending = False
return QuicStopSendingFrame(
error_code=self._stop_error_code,
stream_id=self._stream_id,
)
def starting_offset(self) -> int:
return self._buffer_start
def handle_frame(self, frame: QuicStreamFrame) -> events.StreamDataReceived | None:
"""
Handle a frame of received data.
"""
pos = frame.offset - self._buffer_start
count = len(frame.data)
frame_end = frame.offset + count
# we should receive no more data beyond FIN!
if self._final_size is not None:
if frame_end > self._final_size:
raise FinalSizeError("Data received beyond final size")
elif frame.fin and frame_end != self._final_size:
raise FinalSizeError("Cannot change final size")
if frame.fin:
self._final_size = frame_end
if frame_end > self.highest_offset:
self.highest_offset = frame_end
# fast path: new in-order chunk
if pos == 0 and count and not self._buffer:
self._buffer_start += count
if frame.fin:
# all data up to the FIN has been received, we're done receiving
self.is_finished = True
return events.StreamDataReceived(
data=frame.data, end_stream=frame.fin, stream_id=self._stream_id
)
# discard duplicate data
if pos < 0:
frame.data = frame.data[-pos:]
frame.offset -= pos
pos = 0
count = len(frame.data)
# marked received range
if frame_end > frame.offset:
self._ranges.add(frame.offset, frame_end)
# add new data
gap = pos - len(self._buffer)
if gap > 0:
self._buffer += bytearray(gap)
self._buffer[pos : pos + count] = frame.data
# return data from the front of the buffer
data = self._pull_data()
end_stream = self._buffer_start == self._final_size
if end_stream:
# all data up to the FIN has been received, we're done receiving
self.is_finished = True
if data or end_stream:
return events.StreamDataReceived(
data=data, end_stream=end_stream, stream_id=self._stream_id
)
else:
return None
def handle_reset(
self, *, final_size: int, error_code: int = QuicErrorCode.NO_ERROR
) -> events.StreamReset | None:
"""
Handle an abrupt termination of the receiving part of the QUIC stream.
"""
if self._final_size is not None and final_size != self._final_size:
raise FinalSizeError("Cannot change final size")
# we are done receiving
self._final_size = final_size
self.is_finished = True
return events.StreamReset(error_code=error_code, stream_id=self._stream_id)
def on_stop_sending_delivery(self, delivery: QuicDeliveryState) -> None:
"""
Callback when a STOP_SENDING is ACK'd.
"""
if delivery != QuicDeliveryState.ACKED:
self.stop_pending = True
def stop(self, error_code: int = QuicErrorCode.NO_ERROR) -> None:
"""
Request the peer stop sending data on the QUIC stream.
"""
self._stop_error_code = error_code
self.stop_pending = True
def _pull_data(self) -> bytes:
"""
Remove data from the front of the buffer.
"""
try:
has_data_to_read = self._ranges[0][0] == self._buffer_start
except IndexError:
has_data_to_read = False
if not has_data_to_read:
return b""
r = self._ranges.shift()
pos = r[1] - r[0]
data = bytes(self._buffer[:pos])
del self._buffer[:pos]
self._buffer_start = r[1]
return data
class QuicStreamSender:
"""
The send part of a QUIC stream.
It finishes:
- immediately for a receive-only stream
- upon acknowledgement of a STREAM_RESET frame
- upon acknowledgement of a data frame with the FIN bit set
"""
__slots__ = (
"buffer_is_empty",
"highest_offset",
"is_finished",
"reset_pending",
"_acked",
"_buffer",
"_buffer_fin",
"_buffer_start",
"_buffer_stop",
"_pending",
"_pending_eof",
"_reset_error_code",
"_stream_id",
"send_buffer_empty",
)
def __init__(self, stream_id: int | None, writable: bool) -> None:
self.buffer_is_empty = True
self.highest_offset = 0
self.is_finished = not writable
self.reset_pending = False
self._acked = RangeSet()
self._buffer = bytearray()
self._buffer_fin: int | None = None
self._buffer_start = 0 # the offset for the start of the buffer
self._buffer_stop = 0 # the offset for the stop of the buffer
self._pending = RangeSet()
self._pending_eof = False
self._reset_error_code: int | None = None
self._stream_id = stream_id
@property
def next_offset(self) -> int:
"""
The offset for the next frame to send.
This is used to determine the space needed for the frame's `offset` field.
"""
try:
return self._pending[0][0]
except IndexError:
return self._buffer_stop
def get_frame(
self, max_size: int, max_offset: int | None = None
) -> QuicStreamFrame | None:
"""
Get a frame of data to send.
"""
# get the first pending data range
try:
r = self._pending[0]
except IndexError:
if self._pending_eof:
# FIN only
self._pending_eof = False
return QuicStreamFrame(fin=True, offset=self._buffer_fin)
self.buffer_is_empty = True
return None
# apply flow control
start = r[0]
stop = min(r[1], start + max_size)
if max_offset is not None and stop > max_offset:
stop = max_offset
if stop <= start:
return None
# create frame
frame = QuicStreamFrame(
data=bytes(
self._buffer[start - self._buffer_start : stop - self._buffer_start]
),
offset=start,
)
self._pending.subtract(start, stop)
# track the highest offset ever sent
if stop > self.highest_offset:
self.highest_offset = stop
# if the buffer is empty and EOF was written, set the FIN bit
if self._buffer_fin == stop:
frame.fin = True
self._pending_eof = False
return frame
def get_reset_frame(self) -> QuicResetStreamFrame:
self.reset_pending = False
return QuicResetStreamFrame(
error_code=self._reset_error_code,
final_size=self.highest_offset,
stream_id=self._stream_id,
)
def on_data_delivery(
self, delivery: QuicDeliveryState, start: int, stop: int
) -> None:
"""
Callback when sent data is ACK'd.
"""
self.buffer_is_empty = False
if delivery == QuicDeliveryState.ACKED:
if stop > start:
self._acked.add(start, stop)
first_range = self._acked[0]
if first_range[0] == self._buffer_start:
size = first_range[1] - first_range[0]
self._acked.shift()
self._buffer_start += size
del self._buffer[:size]
if self._buffer_start == self._buffer_fin:
# all date up to the FIN has been ACK'd, we're done sending
self.is_finished = True
else:
if stop > start:
self._pending.add(start, stop)
if stop == self._buffer_fin:
self.send_buffer_empty = False
self._pending_eof = True
def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
"""
Callback when a reset is ACK'd.
"""
if delivery == QuicDeliveryState.ACKED:
# the reset has been ACK'd, we're done sending
self.is_finished = True
else:
self.reset_pending = True
def reset(self, error_code: int) -> None:
"""
Abruptly terminate the sending part of the QUIC stream.
Once this method has been called, any further calls to it
will have no effect.
"""
if self._reset_error_code is None:
self._reset_error_code = error_code
self.reset_pending = True
# Prevent any more data from being sent or re-sent.
self.buffer_is_empty = True
def write(self, data: bytes, end_stream: bool = False) -> None:
"""
Write some data bytes to the QUIC stream.
"""
assert self._buffer_fin is None, "cannot call write() after FIN"
assert self._reset_error_code is None, "cannot call write() after reset()"
size = len(data)
if size:
self.buffer_is_empty = False
self._pending.add(self._buffer_stop, self._buffer_stop + size)
self._buffer += data
self._buffer_stop += size
if end_stream:
self.buffer_is_empty = False
self._buffer_fin = self._buffer_stop
self._pending_eof = True
class QuicStream:
__slots__ = (
"is_blocked",
"max_stream_data_local",
"max_stream_data_local_sent",
"max_stream_data_remote",
"receiver",
"sender",
"stream_id",
)
def __init__(
self,
stream_id: int | None = None,
max_stream_data_local: int = 0,
max_stream_data_remote: int = 0,
readable: bool = True,
writable: bool = True,
) -> None:
self.is_blocked = False
self.max_stream_data_local = max_stream_data_local
self.max_stream_data_local_sent = max_stream_data_local
self.max_stream_data_remote = max_stream_data_remote
self.receiver = QuicStreamReceiver(stream_id=stream_id, readable=readable)
self.sender = QuicStreamSender(stream_id=stream_id, writable=writable)
self.stream_id = stream_id
@property
def is_finished(self) -> bool:
return self.receiver.is_finished and self.sender.is_finished

File diff suppressed because it is too large Load Diff