from __future__ import annotations

import binascii
import logging
import os
from collections import deque
from dataclasses import dataclass
from enum import IntEnum
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Sequence

if TYPE_CHECKING:
    # Only needed for annotations (lazy thanks to `from __future__ import
    # annotations`), so they are not imported at runtime.
    from .configuration import QuicConfiguration
    from .logger import QuicLoggerTrace

# Runtime import: `tls.Epoch` is used below at module import time.
from .. import tls
from .._compat import DATACLASS_KWARGS, UINT_VAR_MAX, UINT_VAR_MAX_SIZE
from .._hazmat import Buffer, BufferReadError, size_uint_var
from .._hazmat import Certificate as X509Certificate
from . import events
from .crypto import CryptoError, CryptoPair, KeyUnavailableError, NoCallback
from .packet import (
    CONNECTION_ID_MAX_SIZE,
    NON_ACK_ELICITING_FRAME_TYPES,
    PROBING_FRAME_TYPES,
    RETRY_INTEGRITY_TAG_SIZE,
    STATELESS_RESET_TOKEN_SIZE,
    QuicErrorCode,
    QuicFrameType,
    QuicHeader,
    QuicPacketType,
    QuicProtocolVersion,
    QuicStreamFrame,
    QuicTransportParameters,
    QuicVersionInformation,
    get_retry_integrity_tag,
    get_spin_bit,
    pretty_protocol_version,
    pull_ack_frame,
    pull_quic_header,
    pull_quic_transport_parameters,
    push_ack_frame,
    push_quic_transport_parameters,
)
from .packet_builder import (
    MTU_PROBE_SIZES,
    SMALLEST_MAX_DATAGRAM_SIZE,
    QuicDeliveryState,
    QuicPacketBuilder,
    QuicPacketBuilderStop,
)
from .recovery import K_GRANULARITY, QuicPacketRecovery, QuicPacketSpace
from .stream import FinalSizeError, QuicStream, StreamFinishedError

# Module-wide logger. Each connection wraps it in a QuicConnectionAdapter so
# that messages are prefixed with the connection's identifier.
logger = logging.getLogger("quic")

# NOTE(review): presumably a crypto-stream buffer size in bytes — confirm
# against its users (not visible in this chunk).
CRYPTO_BUFFER_SIZE = 16384

# Single-character codes understood by EPOCHS(), used to build the set of
# epochs in which each frame type is accepted.
EPOCH_SHORTCUTS = {
    "I": tls.Epoch.INITIAL,
    "H": tls.Epoch.HANDSHAKE,
    "0": tls.Epoch.ZERO_RTT,
    "1": tls.Epoch.ONE_RTT,
}

# NOTE(review): presumably the max_early_data value used for 0-RTT — confirm.
MAX_EARLY_DATA = 0xFFFFFFFF

# NOTE(review): the three bounds below presumably cap queued PATH_CHALLENGE
# data (remote / local) and pending RETIRE_CONNECTION_ID work — confirm
# against their users, which are not visible in this chunk.
MAX_REMOTE_CHALLENGES = 32
MAX_LOCAL_CHALLENGES = 5
MAX_PENDING_RETIRES = 100

# TLS secret labels (key-log style names). Row 0 holds client labels and
# row 1 server labels; the column order presumably follows tls.Epoch
# numbering — verify before relying on the index mapping.
SECRETS_LABELS = [
    [
        None,
        "CLIENT_EARLY_TRAFFIC_SECRET",
        "CLIENT_HANDSHAKE_TRAFFIC_SECRET",
        "CLIENT_TRAFFIC_SECRET_0",
    ],
    [
        None,
        None,
        "SERVER_HANDSHAKE_TRAFFIC_SECRET",
        "SERVER_TRAFFIC_SECRET_0",
    ],
]

# Low three bits of a STREAM frame type (frame types 0x08-0x0F) carry flags.
STREAM_FLAGS = 0x07

# 2**60; presumably the RFC 9000 upper bound on stream counts.
STREAM_COUNT_MAX = 0x1000000000000000

# Added to the datagram payload length when logging raw datagram sizes.
UDP_HEADER_SIZE = 8
MAX_PENDING_CRYPTO = 524288  # in bytes

# Type alias for a peer network address. Kept as ``Any`` so the connection
# does not depend on the transport's address representation.
NetworkAddress = Any

# frame sizes
# Worst-case encoded sizes, in bytes, of the frames this module writes
# (presumably passed as ``capacity`` to ``QuicPacketBuilder.start_frame``).
ACK_FRAME_CAPACITY = 64  # FIXME: this is arbitrary!
APPLICATION_CLOSE_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE  # + reason length
CONNECTION_LIMIT_FRAME_CAPACITY = 1 + UINT_VAR_MAX_SIZE
HANDSHAKE_DONE_FRAME_CAPACITY = 1
MAX_STREAM_DATA_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE
NEW_CONNECTION_ID_FRAME_CAPACITY = (
    1 + 2 * UINT_VAR_MAX_SIZE + 1 + CONNECTION_ID_MAX_SIZE + STATELESS_RESET_TOKEN_SIZE
)
PATH_CHALLENGE_FRAME_CAPACITY = 1 + 8
PATH_RESPONSE_FRAME_CAPACITY = 1 + 8
PING_FRAME_CAPACITY = 1
RESET_STREAM_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE
RETIRE_CONNECTION_ID_CAPACITY = 1 + UINT_VAR_MAX_SIZE
STOP_SENDING_FRAME_CAPACITY = 1 + 2 * UINT_VAR_MAX_SIZE
STREAMS_BLOCKED_CAPACITY = 1 + UINT_VAR_MAX_SIZE
TRANSPORT_CLOSE_FRAME_CAPACITY = 1 + 3 * UINT_VAR_MAX_SIZE  # + reason length


def EPOCHS(shortcut: str) -> frozenset[tls.Epoch]:
    """
    Expand a string of epoch shortcut characters into a set of epochs.

    Each character must be a key of ``EPOCH_SHORTCUTS`` ("I", "H", "0", "1");
    for example ``EPOCHS("IH")`` is the set {INITIAL, HANDSHAKE}.

    :raises KeyError: if `shortcut` contains an unknown character.
    """
    return frozenset(EPOCH_SHORTCUTS[i] for i in shortcut)


def is_version_compatible(from_version: int, to_version: int) -> bool:
    """
    Return whether it is possible to perform compatible version negotiation
    from `from_version` to `to_version`.
    """
    # Version 1 is compatible with version 2 and vice versa. These are the
    # only compatible versions so far.
    return {from_version, to_version} == {
        QuicProtocolVersion.VERSION_1,
        QuicProtocolVersion.VERSION_2,
    }


def dump_cid(cid: bytes) -> str:
    """
    Return the lowercase hexadecimal representation of a connection ID.
    """
    return binascii.hexlify(cid).decode("ascii")


def get_epoch(packet_type: QuicPacketType) -> tls.Epoch:
    """
    Map a packet type to the epoch whose keys protect it.

    INITIAL, ZERO_RTT and HANDSHAKE map to their own epoch; every other
    packet type maps to ONE_RTT.
    """
    if packet_type == QuicPacketType.INITIAL:
        return tls.Epoch.INITIAL
    elif packet_type == QuicPacketType.ZERO_RTT:
        return tls.Epoch.ZERO_RTT
    elif packet_type == QuicPacketType.HANDSHAKE:
        return tls.Epoch.HANDSHAKE
    else:
        return tls.Epoch.ONE_RTT


def stream_is_client_initiated(stream_id: int) -> bool:
    """
    Returns True if the stream is client initiated (bit 0 of the ID is clear).
    """
    return not (stream_id & 1)


def stream_is_unidirectional(stream_id: int) -> bool:
    """
    Returns True if the stream is unidirectional (bit 1 of the ID is set).
    """
    return bool(stream_id & 2)


@lru_cache()
def check_stream_id_for_sending(is_client: bool, stream_id: int) -> bool:
    """
    Return True if the local endpoint may send on `stream_id`.

    Sending is allowed on every bidirectional stream, and on unidirectional
    streams that the local endpoint initiated.

    :param is_client: Whether the local endpoint is the client. Any truthy or
        falsy value is accepted; it is normalised with ``bool()``.
    :param stream_id: The stream ID to check.
    """
    # bool() normalisation: an `is` comparison against a non-bool value
    # (e.g. 1 or 0) would always be False and give the wrong answer.
    locally_initiated = stream_is_client_initiated(stream_id) == bool(is_client)
    return locally_initiated or not stream_is_unidirectional(stream_id)


@lru_cache()
def check_stream_id_for_receiving(is_client: bool, stream_id: int) -> bool:
    """
    Return True if the local endpoint may receive on `stream_id`.

    Receiving is allowed on every bidirectional stream, and on unidirectional
    streams that the peer initiated.

    :param is_client: Whether the local endpoint is the client. Any truthy or
        falsy value is accepted; it is normalised with ``bool()``.
    :param stream_id: The stream ID to check.
    """
    peer_initiated = stream_is_client_initiated(stream_id) != bool(is_client)
    return peer_initiated or not stream_is_unidirectional(stream_id)


class Limit:
    """
    Bookkeeping for a locally advertised flow-control limit.

    ``frame_type`` and ``name`` identify the limit; QuicConnection creates
    instances for MAX_DATA, MAX_STREAMS_BIDI and MAX_STREAMS_UNI. ``sent`` and
    ``value`` both start at the initial limit and ``used`` starts at zero;
    their later updates are managed by QuicConnection.
    """

    def __init__(self, frame_type: int, name: str, value: int):
        self.frame_type = frame_type
        self.name = name
        self.sent = value
        self.used = 0
        self.value = value


class QuicConnectionError(Exception):
    """
    Connection-level error raised while processing a received packet.

    ``receive_datagram`` catches it and closes the connection using the
    carried error code, frame type and reason phrase.
    """

    def __init__(self, error_code: int, frame_type: int | None, reason_phrase: str):
        # Populate ``self.args``: the class is usually constructed with keyword
        # arguments, which BaseException does not record. Without this, args
        # is empty and pickling/copying the exception fails.
        super().__init__(error_code, frame_type, reason_phrase)
        self.error_code = error_code
        self.frame_type = frame_type
        self.reason_phrase = reason_phrase

    def __str__(self) -> str:
        s = f"Error: {self.error_code}, reason: {self.reason_phrase}"
        if self.frame_type is not None:
            s += f", frame_type: {self.frame_type}"
        return s


class QuicConnectionAdapter(logging.LoggerAdapter):
    """
    Logger adapter that prefixes each message with ``[<extra["id"]>]``.
    """

    def process(self, msg: str, kwargs: Any) -> tuple[str, Any]:
        return "[{}] {}".format(self.extra["id"], msg), kwargs


@dataclass(**DATACLASS_KWARGS)
class QuicConnectionId:
    """
    A connection ID together with its sequence number and reset token.

    ``sequence_number`` may be ``None`` for a peer CID that has not yet been
    learned. ``stateless_reset_token`` may be ``None`` for host CIDs created
    by a client.
    """

    cid: bytes
    sequence_number: int | None
    stateless_reset_token: bytes | None = b""
    # Whether this CID has already been advertised to the peer.
    was_sent: bool = False


class QuicConnectionState(IntEnum):
    """
    Lifecycle states of a QuicConnection.
    """

    FIRSTFLIGHT = 0
    CONNECTED = 1
    CLOSING = 2
    DRAINING = 3
    TERMINATED = 4


class QuicNetworkPath:
    """
    Per-address path state: byte counters, validation status and pending
    PATH_CHALLENGE data received from the peer.
    """

    def __init__(self, addr: NetworkAddress, is_validated: bool = False):
        self.addr: NetworkAddress = addr
        self.bytes_received: int = 0
        self.bytes_sent: int = 0
        self.is_validated: bool = is_validated
        self.local_challenge_sent: bool = False
        self.remote_challenges: deque[bytes] = deque()

    def can_send(self, size: int) -> bool:
        """
        Return True if `size` more bytes may be sent on this path.

        Validated paths are unrestricted. Unvalidated paths are subject to the
        anti-amplification limit: at most three times the bytes received.
        """
        return self.is_validated or (self.bytes_sent + size) <= 3 * self.bytes_received


@dataclass(**DATACLASS_KWARGS)
class QuicReceiveContext:
    """
    Per-packet context passed to ``_payload_received`` while a received
    packet is processed.

    NOTE: this class's final field, ``version: int | None``, is declared on
    the line immediately following this block.
    """

    epoch: tls.Epoch
    host_cid: bytes
    network_path: QuicNetworkPath
    quic_logger_frames: list[Any] | None
    time: float
version: int | None END_STATES = frozenset( [ QuicConnectionState.CLOSING, QuicConnectionState.DRAINING, QuicConnectionState.TERMINATED, ] ) class QuicConnection: """ A QUIC connection. The state machine is driven by three kinds of sources: - the API user requesting data to be send out (see :meth:`connect`, :meth:`reset_stream`, :meth:`send_ping`, :meth:`send_datagram_frame` and :meth:`send_stream_data`) - data being received from the network (see :meth:`receive_datagram`) - a timer firing (see :meth:`handle_timer`) :param configuration: The QUIC configuration to use. """ __slots__ = ( "_configuration", "_is_client", "_ack_delay", "_close_at", "_close_event", "_connect_called", "_cryptos", "_cryptos_initial", "_crypto_buffers", "_crypto_frame_type", "_crypto_packet_version", "_crypto_retransmitted", "_crypto_streams", "_events", "_handshake_complete", "_handshake_confirmed", "_host_cids", "host_cid", "_host_cid_seq", "_local_ack_delay_exponent", "_local_active_connection_id_limit", "_local_challenges", "_local_initial_source_connection_id", "_local_max_data", "_local_max_stream_data_bidi_local", "_local_max_stream_data_bidi_remote", "_local_max_stream_data_uni", "_local_max_streams_bidi", "_local_max_streams_uni", "_local_next_stream_id_bidi", "_local_next_stream_id_uni", "_loss_at", "_network_paths", "_pacing_at", "_packet_number", "_peer_cid", "_peer_cid_available", "_peer_cid_sequence_numbers", "_peer_retire_prior_to", "_peer_token", "_quic_logger", "_remote_ack_delay_exponent", "_remote_active_connection_id_limit", "_remote_initial_source_connection_id", "_remote_max_idle_timeout", "_remote_max_data", "_remote_max_data_used", "_remote_max_datagram_frame_size", "_remote_max_stream_data_bidi_local", "_remote_max_stream_data_bidi_remote", "_remote_max_stream_data_uni", "_remote_max_streams_bidi", "_remote_max_streams_uni", "_remote_version_information", "_retry_count", "_retry_source_connection_id", "_spaces", "_spin_bit", "_spin_highest_pn", "_state", "_streams", 
"_streams_dirty_limits", "_streams_queue", "_streams_blocked_bidi", "_streams_blocked_uni", "_streams_finished", "_version", "_version_negotiated_compatible", "_version_negotiated_incompatible", "_original_destination_connection_id", "_logger", "_loss", "_close_pending", "_datagrams_pending", "_handshake_done_pending", "_ping_pending", "_probe_pending", "_retire_connection_ids", "_streams_blocked_pending", "_session_ticket_fetcher", "_session_ticket_handler", "__frame_handlers", "tls", "_local_max_data_used", "_max_datagram_size", "_mtu_probe_sizes", "_mtu_probe_pending", "_initial_source_connection_id", ) def __init__( self, *, configuration: QuicConfiguration, original_destination_connection_id: bytes | None = None, retry_source_connection_id: bytes | None = None, session_ticket_fetcher: tls.SessionTicketFetcher | None = None, session_ticket_handler: tls.SessionTicketHandler | None = None, ) -> None: if configuration.is_client: assert original_destination_connection_id is None, ( "Cannot set original_destination_connection_id for a client" ) assert retry_source_connection_id is None, ( "Cannot set retry_source_connection_id for a client" ) else: assert configuration.certificate is not None, ( "SSL certificate is required for a server" ) assert configuration.private_key is not None, ( "SSL private key is required for a server" ) assert original_destination_connection_id is not None, ( "original_destination_connection_id is required for a server" ) # configuration self._configuration = configuration self._is_client = configuration.is_client self._max_datagram_size = configuration.max_datagram_size self._mtu_probe_sizes: list[int] = ( [s for s in MTU_PROBE_SIZES if s > self._max_datagram_size] if self._is_client and configuration.probe_datagram_size else [] ) self._mtu_probe_pending: int | None = None self._ack_delay = K_GRANULARITY self._close_at: float | None = None self._close_event: events.ConnectionTerminated | None = None self._connect_called = False 
self._cryptos: dict[tls.Epoch, CryptoPair] = {} self._cryptos_initial: dict[int, CryptoPair] = {} self._crypto_buffers: dict[tls.Epoch, Buffer] = {} self._crypto_frame_type: int | None = None self._crypto_packet_version: int | None = None self._crypto_retransmitted = False self._crypto_streams: dict[tls.Epoch, QuicStream] = {} self._events: deque[events.QuicEvent] = deque() self._handshake_complete = False self._handshake_confirmed = False self._host_cids = [ QuicConnectionId( cid=os.urandom(configuration.connection_id_length), sequence_number=0, stateless_reset_token=os.urandom(16) if not self._is_client else None, was_sent=True, ) ] self.host_cid = self._host_cids[0].cid self._host_cid_seq = 1 self._local_ack_delay_exponent = 3 self._local_active_connection_id_limit = 8 self._local_challenges: dict[bytes, QuicNetworkPath] = {} self._local_initial_source_connection_id = self._host_cids[0].cid self._local_max_data = Limit( frame_type=QuicFrameType.MAX_DATA, name="max_data", value=configuration.max_data, ) self._local_max_stream_data_bidi_local = configuration.max_stream_data self._local_max_stream_data_bidi_remote = configuration.max_stream_data self._local_max_stream_data_uni = configuration.max_stream_data self._local_max_streams_bidi = Limit( frame_type=QuicFrameType.MAX_STREAMS_BIDI, name="max_streams_bidi", value=128, ) self._local_max_streams_uni = Limit( frame_type=QuicFrameType.MAX_STREAMS_UNI, name="max_streams_uni", value=128 ) self._local_next_stream_id_bidi = 0 if self._is_client else 1 self._local_next_stream_id_uni = 2 if self._is_client else 3 self._loss_at: float | None = None self._network_paths: list[QuicNetworkPath] = [] self._pacing_at: float | None = None self._packet_number = 0 self._peer_cid = QuicConnectionId( cid=os.urandom(configuration.connection_id_length), sequence_number=None ) self._peer_cid_available: list[QuicConnectionId] = [] self._peer_cid_sequence_numbers: set[int] = {0} self._peer_retire_prior_to = 0 self._peer_token = b"" 
self._quic_logger: QuicLoggerTrace | None = None self._remote_ack_delay_exponent = 3 self._remote_active_connection_id_limit = 2 self._remote_initial_source_connection_id: bytes | None = None self._remote_max_idle_timeout = 0.0 # seconds self._remote_max_data = 0 self._remote_max_data_used = 0 self._remote_max_datagram_frame_size: int | None = None self._remote_max_stream_data_bidi_local = 0 self._remote_max_stream_data_bidi_remote = 0 self._remote_max_stream_data_uni = 0 self._remote_max_streams_bidi = 0 self._remote_max_streams_uni = 0 self._remote_version_information: QuicVersionInformation | None = None self._retry_count = 0 self._retry_source_connection_id = retry_source_connection_id self._spaces: dict[tls.Epoch, QuicPacketSpace] = {} self._spin_bit = False self._spin_highest_pn = 0 self._state = QuicConnectionState.FIRSTFLIGHT self._streams: dict[int, QuicStream] = {} self._streams_dirty_limits: set[QuicStream] = set() self._streams_queue: list[QuicStream] = [] self._streams_blocked_bidi: list[QuicStream] = [] self._streams_blocked_uni: list[QuicStream] = [] self._streams_finished: set[int] = set() self._version: int | None = None self._version_negotiated_compatible = False self._version_negotiated_incompatible = False if self._is_client: self._original_destination_connection_id = self._peer_cid.cid else: self._original_destination_connection_id = ( original_destination_connection_id ) # logging self._logger = QuicConnectionAdapter( logger, {"id": dump_cid(self._original_destination_connection_id)} ) if configuration.quic_logger: self._quic_logger = configuration.quic_logger.start_trace( is_client=configuration.is_client, odcid=self._original_destination_connection_id, ) # loss recovery self._loss = QuicPacketRecovery( initial_rtt=configuration.initial_rtt, peer_completed_address_validation=not self._is_client, send_probe=self._send_probe, max_datagram_size=self._max_datagram_size, quic_logger=self._quic_logger, logger=self._logger, ) # things to send 
self._close_pending = False self._datagrams_pending: deque[bytes] = deque() self._handshake_done_pending = False self._ping_pending: list[int] = [] self._probe_pending = False self._retire_connection_ids: list[int] = [] self._streams_blocked_pending = False # callbacks self._session_ticket_fetcher = session_ticket_fetcher self._session_ticket_handler = session_ticket_handler # frame handlers self.__frame_handlers = { 0x00: (self._handle_padding_frame, EPOCHS("IH01")), 0x01: (self._handle_ping_frame, EPOCHS("IH01")), 0x02: (self._handle_ack_frame, EPOCHS("IH1")), 0x03: (self._handle_ack_frame, EPOCHS("IH1")), 0x04: (self._handle_reset_stream_frame, EPOCHS("01")), 0x05: (self._handle_stop_sending_frame, EPOCHS("01")), 0x06: (self._handle_crypto_frame, EPOCHS("IH1")), 0x07: (self._handle_new_token_frame, EPOCHS("1")), 0x08: (self._handle_stream_frame, EPOCHS("01")), 0x09: (self._handle_stream_frame, EPOCHS("01")), 0x0A: (self._handle_stream_frame, EPOCHS("01")), 0x0B: (self._handle_stream_frame, EPOCHS("01")), 0x0C: (self._handle_stream_frame, EPOCHS("01")), 0x0D: (self._handle_stream_frame, EPOCHS("01")), 0x0E: (self._handle_stream_frame, EPOCHS("01")), 0x0F: (self._handle_stream_frame, EPOCHS("01")), 0x10: (self._handle_max_data_frame, EPOCHS("01")), 0x11: (self._handle_max_stream_data_frame, EPOCHS("01")), 0x12: (self._handle_max_streams_bidi_frame, EPOCHS("01")), 0x13: (self._handle_max_streams_uni_frame, EPOCHS("01")), 0x14: (self._handle_data_blocked_frame, EPOCHS("01")), 0x15: (self._handle_stream_data_blocked_frame, EPOCHS("01")), 0x16: (self._handle_streams_blocked_frame, EPOCHS("01")), 0x17: (self._handle_streams_blocked_frame, EPOCHS("01")), 0x18: (self._handle_new_connection_id_frame, EPOCHS("01")), 0x19: (self._handle_retire_connection_id_frame, EPOCHS("01")), 0x1A: (self._handle_path_challenge_frame, EPOCHS("01")), 0x1B: (self._handle_path_response_frame, EPOCHS("01")), 0x1C: (self._handle_connection_close_frame, EPOCHS("IH01")), 0x1D: 
(self._handle_connection_close_frame, EPOCHS("01")), 0x1E: (self._handle_handshake_done_frame, EPOCHS("1")), 0x30: (self._handle_datagram_frame, EPOCHS("01")), 0x31: (self._handle_datagram_frame, EPOCHS("01")), } @property def open_outbound_streams(self) -> int: return len(self._streams) @property def max_concurrent_bidi_streams(self) -> int: return self._remote_max_streams_bidi @property def max_concurrent_uni_streams(self) -> int: return self._remote_max_streams_uni def get_cipher(self) -> tls.CipherSuite | None: return self.tls.key_schedule.cipher_suite if self.tls.key_schedule else None def get_peercert(self) -> X509Certificate | None: return self.tls.peer_certificate def get_issuercerts(self) -> list[X509Certificate]: return self.tls.peer_certificate_chain @property def configuration(self) -> QuicConfiguration: return self._configuration @property def original_destination_connection_id(self) -> bytes: return self._original_destination_connection_id def change_connection_id(self) -> None: """ Switch to the next available connection ID and retire the previous one. After calling this method call :meth:`datagrams_to_send` to retrieve data which needs to be sent. """ if self._peer_cid_available: # retire previous CID self._retire_peer_cid(self._peer_cid) # assign new CID self._consume_peer_cid() def close( self, error_code: int = QuicErrorCode.NO_ERROR, frame_type: int | None = None, reason_phrase: str = "", ) -> None: """ Close the connection. :param error_code: An error code indicating why the connection is being closed. :param reason_phrase: A human-readable explanation of why the connection is being closed. """ if self._close_event is None and self._state not in END_STATES: self._close_event = events.ConnectionTerminated( error_code=error_code, frame_type=frame_type, reason_phrase=reason_phrase, ) self._close_pending = True def connect(self, addr: NetworkAddress, now: float) -> None: """ Initiate the TLS handshake. 
This method can only be called for clients and a single time. After calling this method call :meth:`datagrams_to_send` to retrieve data which needs to be sent. :param addr: The network address of the remote peer. :param now: The current time. """ assert self._is_client and not self._connect_called, ( "connect() can only be called for clients and a single time" ) self._connect_called = True self._network_paths = [QuicNetworkPath(addr, is_validated=True)] if self._configuration.original_version is not None: self._version = self._configuration.original_version else: self._version = self._configuration.supported_versions[0] self._connect(now=now) def datagrams_to_send(self, now: float) -> list[tuple[bytes, NetworkAddress]]: """ Return a list of `(data, addr)` tuples of datagrams which need to be sent, and the network address to which they need to be sent. After calling this method call :meth:`get_timer` to know when the next timer needs to be set. :param now: The current time. """ network_path = self._network_paths[0] if self._state in END_STATES: return [] # build datagrams builder = QuicPacketBuilder( host_cid=self.host_cid, is_client=self._is_client, max_datagram_size=self._max_datagram_size, packet_number=self._packet_number, peer_cid=self._peer_cid.cid, peer_token=self._peer_token, quic_logger=self._quic_logger, spin_bit=self._spin_bit, version=self._version, ) if self._close_pending: epoch_packet_types = [] if not self._handshake_confirmed: epoch_packet_types += [ (tls.Epoch.INITIAL, QuicPacketType.INITIAL), (tls.Epoch.HANDSHAKE, QuicPacketType.HANDSHAKE), ] epoch_packet_types.append((tls.Epoch.ONE_RTT, QuicPacketType.ONE_RTT)) for epoch, packet_type in epoch_packet_types: crypto = self._cryptos[epoch] if crypto.send.is_valid(): builder.start_packet(packet_type, crypto) self._write_connection_close_frame( builder=builder, epoch=epoch, error_code=self._close_event.error_code, frame_type=self._close_event.frame_type, reason_phrase=self._close_event.reason_phrase, ) 
self._logger.debug( "Connection close sent (code 0x%X, reason %s)", self._close_event.error_code, self._close_event.reason_phrase, ) self._close_pending = False self._close_begin(is_initiator=True, now=now) else: # congestion control builder.max_flight_bytes = ( self._loss.congestion_window - self._loss.bytes_in_flight ) if ( self._probe_pending and builder.max_flight_bytes < self._max_datagram_size ): builder.max_flight_bytes = self._max_datagram_size # limit data on un-validated network paths if not network_path.is_validated: builder.max_total_bytes = ( network_path.bytes_received * 3 - network_path.bytes_sent ) try: if not self._handshake_confirmed: for epoch in [tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE]: self._write_handshake(builder, epoch, now) self._write_application(builder, network_path, now) except QuicPacketBuilderStop: pass datagrams, packets = builder.flush() # MTU probing — send an oversized PING+PADDING packet last_builder = builder if ( self._mtu_probe_sizes and self._mtu_probe_pending is None and self._handshake_confirmed and self._cryptos[tls.Epoch.ONE_RTT].send.is_valid() ): probe_size = self._mtu_probe_sizes[0] self._mtu_probe_pending = probe_size probe_builder = QuicPacketBuilder( host_cid=self.host_cid, is_client=self._is_client, max_datagram_size=probe_size, packet_number=( builder.packet_number if datagrams else self._packet_number ), peer_cid=self._peer_cid.cid, peer_token=self._peer_token, quic_logger=self._quic_logger, spin_bit=self._spin_bit, version=self._version, ) probe_builder.start_packet( QuicPacketType.ONE_RTT, self._cryptos[tls.Epoch.ONE_RTT] ) probe_builder.start_frame( QuicFrameType.PING, capacity=1, handler=self._on_mtu_probe_delivery, handler_args=(probe_size,), ) pad_size = probe_builder.remaining_flight_space if pad_size > 0: probe_builder._buffer.push_bytes(bytes(pad_size)) probe_datagrams, probe_packets = probe_builder.flush() if probe_datagrams: datagrams.extend(probe_datagrams) packets.extend(probe_packets) last_builder = 
probe_builder if datagrams: self._packet_number = last_builder.packet_number # register packets sent_handshake = False for packet in packets: packet.sent_time = now self._loss.on_packet_sent( packet=packet, space=self._spaces[packet.epoch] ) if packet.epoch == tls.Epoch.HANDSHAKE: sent_handshake = True # log packet if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_sent", data={ "frames": packet.quic_logger_frames, "header": { "packet_number": packet.packet_number, "packet_type": self._quic_logger.packet_type( packet.packet_type ), "scid": ( "" if packet.packet_type == QuicPacketType.ONE_RTT else dump_cid(self.host_cid) ), "dcid": dump_cid(self._peer_cid.cid), }, "raw": {"length": packet.sent_bytes}, }, ) # check if we can discard initial keys if sent_handshake and self._is_client: self._discard_epoch(tls.Epoch.INITIAL) # return datagrams to send and the destination network address ret = [] for datagram in datagrams: payload_length = len(datagram) network_path.bytes_sent += payload_length ret.append((datagram, network_path.addr)) if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="datagrams_sent", data={ "count": 1, "raw": [ { "length": UDP_HEADER_SIZE + payload_length, "payload_length": payload_length, } ], }, ) return ret def get_next_available_stream_id(self, is_unidirectional=False) -> int: """ Return the stream ID for the next stream created by this endpoint. """ if is_unidirectional: return self._local_next_stream_id_uni else: return self._local_next_stream_id_bidi def get_timer(self) -> float | None: """ Return the time at which the timer should fire or None if no timer is needed. 
""" timer_at = self._close_at if self._state not in END_STATES: # ack timer for space in self._loss.spaces: if space.ack_at is not None and space.ack_at < timer_at: timer_at = space.ack_at # loss detection timer self._loss_at = self._loss.get_loss_detection_time() if self._loss_at is not None and self._loss_at < timer_at: timer_at = self._loss_at # pacing timer if self._pacing_at is not None and self._pacing_at < timer_at: timer_at = self._pacing_at return timer_at def handle_timer(self, now: float) -> None: """ Handle the timer. After calling this method call :meth:`datagrams_to_send` to retrieve data which needs to be sent. :param now: The current time. """ # end of closing period or idle timeout if now >= self._close_at: if self._close_event is None: self._close_event = events.ConnectionTerminated( error_code=QuicErrorCode.INTERNAL_ERROR, frame_type=QuicFrameType.PADDING, reason_phrase="Idle timeout", ) self._close_end() return # loss detection timeout if self._loss_at is not None and now >= self._loss_at: self._logger.debug("Loss detection triggered") self._loss.on_loss_detection_timeout(now=now) def next_event(self) -> events.QuicEvent | None: """ Retrieve the next event from the event buffer. Returns `None` if there are no buffered events. """ try: return self._events.popleft() except IndexError: return None def receive_datagram(self, data: bytes, addr: NetworkAddress, now: float) -> None: """ Handle an incoming datagram. After calling this method call :meth:`datagrams_to_send` to retrieve data which needs to be sent. :param data: The datagram which was received. :param addr: The network address from which the datagram was received. :param now: The current time. 
""" # stop handling packets when closing if self._state in END_STATES: return payload_length = len(data) # log datagram if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="datagrams_received", data={ "count": 1, "raw": [ { "length": UDP_HEADER_SIZE + payload_length, "payload_length": payload_length, } ], }, ) # For anti-amplification purposes, servers need to keep track of the # amount of data received on unvalidated network paths. We must count the # entire datagram size regardless of whether packets are processed or # dropped. # # This is particularly important when talking to clients who pad # datagrams containing INITIAL packets by appending bytes after the # long-header packets, which is legitimate behaviour. # # https://datatracker.ietf.org/doc/html/rfc9000#section-8.1 network_path = self._find_network_path(addr) if not network_path.is_validated: network_path.bytes_received += payload_length # for servers, arm the idle timeout on the first datagram if self._close_at is None: self._close_at = now + self._configuration.idle_timeout buf = Buffer(data=data) while not buf.eof(): start_off = buf.tell() try: header = pull_quic_header( buf, host_cid_length=self._configuration.connection_id_length ) except ValueError: if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "header_parse_error", "raw": {"length": buf.capacity - start_off}, }, ) return # check destination CID matches destination_cid_seq: int | None = None for connection_id in self._host_cids: if header.destination_cid == connection_id.cid: destination_cid_seq = connection_id.sequence_number break if self._is_client and destination_cid_seq is None: if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={"trigger": "unknown_connection_id"}, ) return # Handle version negotiation packet. 
if header.packet_type == QuicPacketType.VERSION_NEGOTIATION: self._receive_version_negotiation_packet(header=header, now=now) return # Check long header packet protocol version. if ( header.version is not None and header.version not in self._configuration.supported_versions ): if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "unsupported_version", "raw": {"length": header.packet_length}, }, ) return # handle retry packet if header.packet_type == QuicPacketType.RETRY: self._receive_retry_packet( header=header, packet_without_tag=buf.data_slice( start_off, buf.tell() - RETRY_INTEGRITY_TAG_SIZE ), now=now, ) return crypto_frame_required = False # server initialization if not self._is_client and self._state is QuicConnectionState.FIRSTFLIGHT: assert header.packet_type == QuicPacketType.INITIAL, ( "first packet must be INITIAL" ) crypto_frame_required = True self._network_paths = [network_path] self._version = header.version self._initialize(header.destination_cid) # Determine crypto and packet space. epoch = get_epoch(header.packet_type) if epoch == tls.Epoch.INITIAL: crypto = self._cryptos_initial[header.version] else: crypto = self._cryptos[epoch] if epoch == tls.Epoch.ZERO_RTT: space = self._spaces[tls.Epoch.ONE_RTT] else: space = self._spaces[epoch] # decrypt packet encrypted_off = buf.tell() - start_off end_off = start_off + header.packet_length buf.seek(end_off) try: plain_header, plain_payload, packet_number = crypto.decrypt_packet( data[start_off:end_off], encrypted_off, space.expected_packet_number ) except KeyUnavailableError as exc: self._logger.debug(exc) if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "key_unavailable", "raw": {"length": header.packet_length}, }, ) # If a client receives HANDSHAKE or 1-RTT packets before it has # handshake keys, it can assume that the server's INITIAL was lost. 
# A client that fails to decrypt a HANDSHAKE/1-RTT packet here reschedules its data once.
                if (
                    self._is_client
                    and epoch in (tls.Epoch.HANDSHAKE, tls.Epoch.ONE_RTT)
                    and not self._crypto_retransmitted
                ):
                    self._loss.reschedule_data(now=now)
                    self._crypto_retransmitted = True
                continue
            except CryptoError as exc:
                self._logger.debug(exc)
                if self._quic_logger is not None:
                    self._quic_logger.log_event(
                        category="transport",
                        event="packet_dropped",
                        data={
                            "trigger": "payload_decrypt_error",
                            "raw": {"length": header.packet_length},
                        },
                    )
                continue

            # check reserved bits
            if header.packet_type == QuicPacketType.ONE_RTT:
                reserved_mask = 0x18
            else:
                reserved_mask = 0x0C
            if plain_header[0] & reserved_mask:
                self.close(
                    error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                    frame_type=QuicFrameType.PADDING,
                    reason_phrase="Reserved bits must be zero",
                )
                return

            # log packet
            quic_logger_frames: list[dict] | None = None
            if self._quic_logger is not None:
                quic_logger_frames = []
                self._quic_logger.log_event(
                    category="transport",
                    event="packet_received",
                    data={
                        "frames": quic_logger_frames,
                        "header": {
                            "packet_number": packet_number,
                            "packet_type": self._quic_logger.packet_type(
                                header.packet_type
                            ),
                            "dcid": dump_cid(header.destination_cid),
                            "scid": dump_cid(header.source_cid),
                        },
                        "raw": {"length": header.packet_length},
                    },
                )

            # raise expected packet number
            if packet_number > space.expected_packet_number:
                space.expected_packet_number = packet_number + 1

            # discard initial keys and packet space
            if not self._is_client and epoch == tls.Epoch.HANDSHAKE:
                self._discard_epoch(tls.Epoch.INITIAL)

            # update state
            if self._peer_cid.sequence_number is None:
                self._peer_cid.cid = header.source_cid
                self._peer_cid.sequence_number = 0

            if self._state is QuicConnectionState.FIRSTFLIGHT:
                self._remote_initial_source_connection_id = header.source_cid
                self._set_state(QuicConnectionState.CONNECTED)

            # update spin bit
            if (
                header.packet_type == QuicPacketType.ONE_RTT
                and packet_number > self._spin_highest_pn
            ):
                spin_bit = get_spin_bit(plain_header[0])
                if self._is_client:
                    self._spin_bit = not spin_bit
                else:
                    self._spin_bit = spin_bit
                self._spin_highest_pn = packet_number

                if self._quic_logger is not None:
                    self._quic_logger.log_event(
                        category="connectivity",
                        event="spin_bit_updated",
                        data={"state": self._spin_bit},
                    )

            # handle payload
            context = QuicReceiveContext(
                epoch=epoch,
                host_cid=header.destination_cid,
                network_path=network_path,
                quic_logger_frames=quic_logger_frames,
                time=now,
                version=header.version,
            )
            try:
                is_ack_eliciting, is_probing = self._payload_received(
                    context, plain_payload, crypto_frame_required
                )
            except QuicConnectionError as exc:
                self._logger.debug(exc)
                self.close(
                    error_code=exc.error_code,
                    frame_type=exc.frame_type,
                    reason_phrase=exc.reason_phrase,
                )
            if self._state in END_STATES or self._close_pending:
                return

            # update idle timeout
            self._close_at = now + self._configuration.idle_timeout

            # handle migration
            if (
                not self._is_client
                and context.host_cid != self.host_cid
                and epoch == tls.Epoch.ONE_RTT
            ):
                self._logger.debug(
                    "Peer switching to CID %s (%d)",
                    dump_cid(context.host_cid),
                    destination_cid_seq,
                )
                self.host_cid = context.host_cid
                self.change_connection_id()

            # update network path
            if not network_path.is_validated and epoch == tls.Epoch.HANDSHAKE:
                self._logger.debug(
                    "Network path %s validated by handshake", network_path.addr
                )
                network_path.is_validated = True
            if network_path not in self._network_paths:
                self._network_paths.append(network_path)
            idx = self._network_paths.index(network_path)
            if idx and not is_probing and packet_number > space.largest_received_packet:
                self._logger.debug("Network path %s promoted", network_path.addr)
                self._network_paths.pop(idx)
                self._network_paths.insert(0, network_path)

            # record packet as received
            if not space.discarded:
                if packet_number > space.largest_received_packet:
                    space.largest_received_packet = packet_number
                    space.largest_received_time = now
                space.ack_queue.add(packet_number)
                if is_ack_eliciting and space.ack_at is None:
                    space.ack_at = now + self._ack_delay

    def request_key_update(self) -> None:
        """
        Request an update of the encryption keys.

        Only the 1-RTT crypto pair is updated.

        The handshake-complete precondition is checked with ``assert``, so it
        raises AssertionError when violated and is skipped under ``python -O``.
        """
        assert self._handshake_complete, "cannot change key before handshake completes"
        self._cryptos[tls.Epoch.ONE_RTT].update_key()

    def reset_stream(self, stream_id: int, error_code: int) -> None:
        """
        Abruptly terminate the sending part of a stream.

        This method has no effect if a reset has already been triggered either by
        a call to :meth:`reset_stream` or by the reception of a STOP_SENDING frame.

        :param stream_id: The stream's ID.
        :param error_code: An error code indicating why the stream is being reset.
        :raises ValueError: propagated from the stream lookup when the stream ID
            cannot be used for sending by this endpoint.
        """
        stream = self._get_or_create_stream_for_send(stream_id)
        stream.sender.reset(error_code)

    def send_ping(self, uid: int) -> None:
        """
        Send a PING frame to the peer.

        The PING is only queued here (appended to ``_ping_pending``); nothing is
        written to the network by this call.

        :param uid: A unique ID for this PING.
        """
        self._ping_pending.append(uid)

    def send_datagram_frame(self, data: bytes) -> None:
        """
        Send a DATAGRAM frame.

        The payload is only queued here (appended to ``_datagrams_pending``).

        :param data: The data to be sent.
        """
        self._datagrams_pending.append(data)

    def send_stream_data(
        self, stream_id: int, data: bytes, end_stream: bool = False
    ) -> None:
        """
        Send data on the specific stream.

        The stream is created on first use if this endpoint may open it.

        :param stream_id: The stream's ID.
        :param data: The data to be sent.
        :param end_stream: If set to `True`, the FIN bit will be set.
        :raises ValueError: propagated from the stream lookup when the stream ID
            cannot be used for sending by this endpoint.
        """
        stream = self._get_or_create_stream_for_send(stream_id)
        stream.sender.write(data, end_stream=end_stream)

    def stop_stream(self, stream_id: int, error_code: int) -> None:
        """
        Request termination of the receiving part of a stream.

        :param stream_id: The stream's ID.
        :param error_code: An error code indicating why the stream is being stopped.
        :raises ValueError: if the stream cannot receive data (a locally-initiated
            unidirectional stream) or is not currently known.
        """
        if not check_stream_id_for_receiving(self._is_client, stream_id):
            raise ValueError(
                "Cannot stop receiving on a local-initiated unidirectional stream"
            )

        # Unlike sending, stopping never implicitly creates the stream.
        stream = self._streams.get(stream_id, None)
        if stream is None:
            raise ValueError("Cannot stop receiving on an unknown stream")

        stream.receiver.stop(error_code)

    # Private

    def _alpn_handler(self, alpn_protocol: str) -> None:
        """
        Callback which is invoked by the TLS engine at most once, when the
        ALPN negotiation completes.

        At this point, TLS extensions have been received so we can parse the
        transport parameters.

        :param alpn_protocol: The negotiated ALPN protocol; it is forwarded to the
            application in a ``ProtocolNegotiated`` event.
        :raises QuicConnectionError: CRYPTO_ERROR + ``missing_extension`` when the
            peer sent no QUIC transport parameters extension.
        """

        # Parse the remote transport parameters.
        for ext_type, ext_data in self.tls.received_extensions:
            if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS:
                self._parse_transport_parameters(ext_data)
                break
        else:
            # for/else: reached only when the loop did not ``break``.
            raise QuicConnectionError(
                error_code=QuicErrorCode.CRYPTO_ERROR
                + tls.AlertDescription.missing_extension,
                frame_type=self._crypto_frame_type,
                reason_phrase="No QUIC transport parameters received",
            )

        # For servers, determine the Negotiated Version.
        if not self._is_client and not self._version_negotiated_compatible:
            if self._remote_version_information is not None:
                # Pick the first version we support in the client's available versions,
                # which is compatible with the current version.
                for version in self._remote_version_information.available_versions:
                    if version == self._version:
                        # Stay with the current version.
                        break
                    elif (
                        version in self._configuration.supported_versions
                        and is_version_compatible(self._version, version)
                    ):
                        # Change version.
                        self._version = version
                        # Initial keys for every supported version are prepared
                        # up front by _initialize, so switching is a lookup.
                        self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[
                            version
                        ]

                        # Update our transport parameters to reflect the chosen version.
                        self.tls.handshake_extensions = [
                            (
                                tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
                                self._serialize_transport_parameters(),
                            )
                        ]
                        break
            self._version_negotiated_compatible = True
            self._logger.debug(
                "Negotiated protocol version %s", pretty_protocol_version(self._version)
            )

        # Notify the application.
        self._events.append(events.ProtocolNegotiated(alpn_protocol=alpn_protocol))

    def _assert_stream_can_receive(self, frame_type: int, stream_id: int) -> None:
        """
        Check the specified stream can receive data or raises a QuicConnectionError.

        :raises QuicConnectionError: STREAM_STATE_ERROR for a send-only stream.
        """
        if not check_stream_id_for_receiving(self._is_client, stream_id):
            raise QuicConnectionError(
                error_code=QuicErrorCode.STREAM_STATE_ERROR,
                frame_type=frame_type,
                reason_phrase="Stream is send-only",
            )

    def _assert_stream_can_send(self, frame_type: int, stream_id: int) -> None:
        """
        Check the specified stream can send data or raises a QuicConnectionError.

        :raises QuicConnectionError: STREAM_STATE_ERROR for a receive-only stream.
        """
        if not check_stream_id_for_sending(self._is_client, stream_id):
            raise QuicConnectionError(
                error_code=QuicErrorCode.STREAM_STATE_ERROR,
                frame_type=frame_type,
                reason_phrase="Stream is receive-only",
            )

    def _consume_peer_cid(self) -> None:
        """
        Update the destination connection ID by taking the next available
        connection ID provided by the peer.

        ``pop(0)`` raises IndexError if ``_peer_cid_available`` is empty.
        """
        self._peer_cid = self._peer_cid_available.pop(0)
        self._logger.debug(
            "Switching to CID %s (%d)",
            dump_cid(self._peer_cid.cid),
            self._peer_cid.sequence_number,
        )

    def _close_begin(self, is_initiator: bool, now: float) -> None:
        """
        Begin the close procedure.

        :param is_initiator: True when this endpoint initiated the close (enter
            CLOSING); False when the peer did (enter DRAINING).
        :param now: The current time, used to set the close deadline.
        """
        # Stay in the closing/draining state for three probe timeouts
        # (cf. RFC 9000 section 10.2).
        self._close_at = now + 3 * self._loss.get_probe_timeout()
        if is_initiator:
            self._set_state(QuicConnectionState.CLOSING)
        else:
            self._set_state(QuicConnectionState.DRAINING)

    def _close_end(self) -> None:
        """
        End the close procedure.

        Discards every packet space, emits the stored close event, enters the
        TERMINATED state and ends the qlog trace if one is active.
        """
        self._close_at = None
        for epoch in self._spaces.keys():
            self._discard_epoch(epoch)
        self._events.append(self._close_event)
        self._set_state(QuicConnectionState.TERMINATED)

        # signal log end
        if self._quic_logger is not None:
            self._configuration.quic_logger.end_trace(self._quic_logger)
            self._quic_logger = None

    def _connect(self, now: float) -> None:
        """
        Start the client handshake.

        Clients only (asserted). Arms the idle timer, derives the Initial keys
        from the current peer CID and feeds an empty message to TLS so that the
        first handshake data is produced and pushed to the crypto streams.
        """
        assert self._is_client

        self._close_at = now + self._configuration.idle_timeout
        self._initialize(self._peer_cid.cid)

        self.tls.handle_message(b"", self._crypto_buffers)
        self._push_crypto_data()

    def _discard_epoch(self, epoch: tls.Epoch) -> None:
        """
        Drop the keys and packet space of ``epoch``.

        A no-op when the space is already marked as discarded.
        """
        if not self._spaces[epoch].discarded:
            self._logger.debug("Discarding epoch %s", epoch)
            self._cryptos[epoch].teardown()
            if epoch == tls.Epoch.INITIAL:
                # Tear the crypto pairs, but do not log the event,
                # to avoid duplicate log entries.
                for crypto in self._cryptos_initial.values():
                    crypto.recv._teardown_cb = NoCallback
                    crypto.send._teardown_cb = NoCallback
                    crypto.teardown()
            self._loss.discard_space(self._spaces[epoch])
            self._spaces[epoch].discarded = True

    def _find_network_path(self, addr: NetworkAddress) -> QuicNetworkPath:
        """
        Return the known network path for ``addr``, or a new QuicNetworkPath.

        A newly created path is NOT appended to ``_network_paths`` here.
        """
        # check existing network paths
        for network_path in self._network_paths:
            if network_path.addr == addr:
                return network_path

        # new network path
        self._logger.debug("Network path %s discovered", addr)
        return QuicNetworkPath(addr)

    def _get_or_create_stream(self, frame_type: int, stream_id: int) -> QuicStream:
        """
        Get or create a stream in response to a received frame.

        :raises StreamFinishedError: if the stream existed but its state was
            already discarded.
        :raises QuicConnectionError: STREAM_STATE_ERROR when the frame refers to
            an unknown stream that this endpoint would have to initiate, or
            STREAM_LIMIT_ERROR when opening it exceeds our advertised limit.
        """
        if stream_id in self._streams_finished:
            # the stream was created, but its state was since discarded
            raise StreamFinishedError

        stream = self._streams.get(stream_id, None)
        if stream is None:
            # check initiator: the peer may only implicitly open its own streams
            if stream_is_client_initiated(stream_id) is self._is_client:
                raise QuicConnectionError(
                    error_code=QuicErrorCode.STREAM_STATE_ERROR,
                    frame_type=frame_type,
                    reason_phrase="Wrong stream initiator",
                )

            # determine limits
            if stream_is_unidirectional(stream_id):
                max_stream_data_local = self._local_max_stream_data_uni
                max_stream_data_remote = 0
                max_streams = self._local_max_streams_uni
            else:
                max_stream_data_local = self._local_max_stream_data_bidi_remote
                max_stream_data_remote = self._remote_max_stream_data_bidi_local
                max_streams = self._local_max_streams_bidi

            # check max streams
            # The two low bits of a stream ID encode its type (RFC 9000 section
            # 2.1), so id // 4 + 1 is the number of streams of that type needed.
            stream_count = (stream_id // 4) + 1
            if stream_count > max_streams.value:
                raise QuicConnectionError(
                    error_code=QuicErrorCode.STREAM_LIMIT_ERROR,
                    frame_type=frame_type,
                    reason_phrase="Too many streams open",
                )
            elif stream_count > max_streams.used:
                max_streams.used = stream_count

            # create stream
            self._logger.debug(f"Stream {stream_id} created by peer")
            stream = self._streams[stream_id] = QuicStream(
                stream_id=stream_id,
                max_stream_data_local=max_stream_data_local,
                max_stream_data_remote=max_stream_data_remote,
                writable=not stream_is_unidirectional(stream_id),
            )
            self._streams_queue.append(stream)
        return stream

    def _get_or_create_stream_for_send(self, stream_id: int) -> QuicStream:
        """
        Get or create a QUIC stream in order to send data to the peer.

        This always occurs as a result of an API call.

        :raises ValueError: for a peer-initiated unidirectional stream, or for an
            unknown stream that the peer would have had to initiate.
        """
        if not check_stream_id_for_sending(self._is_client, stream_id):
            raise ValueError("Cannot send data on peer-initiated unidirectional stream")

        stream = self._streams.get(stream_id, None)
        if stream is None:
            # check initiator
            if stream_is_client_initiated(stream_id) is not self._is_client:
                raise ValueError("Cannot send data on unknown peer-initiated stream")

            # determine limits
            if stream_is_unidirectional(stream_id):
                max_stream_data_local = 0
                max_stream_data_remote = self._remote_max_stream_data_uni
                max_streams = self._remote_max_streams_uni
                streams_blocked = self._streams_blocked_uni
            else:
                max_stream_data_local = self._local_max_stream_data_bidi_local
                max_stream_data_remote = self._remote_max_stream_data_bidi_remote
                max_streams = self._remote_max_streams_bidi
                streams_blocked = self._streams_blocked_bidi

            # create stream
            is_unidirectional = stream_is_unidirectional(stream_id)
            stream = self._streams[stream_id] = QuicStream(
                stream_id=stream_id,
                max_stream_data_local=max_stream_data_local,
                max_stream_data_remote=max_stream_data_remote,
                readable=not is_unidirectional,
            )
            self._streams_queue.append(stream)

            # Stream IDs of one type advance in steps of 4.
            if is_unidirectional:
                self._local_next_stream_id_uni = stream_id + 4
            else:
                self._local_next_stream_id_bidi = stream_id + 4

            # mark stream as blocked if needed
            # (the peer's MAX_STREAMS does not yet allow this stream index)
            if stream_id // 4 >= max_streams:
                stream.is_blocked = True
                streams_blocked.append(stream)
                self._streams_blocked_pending = True
        return stream

    def _handle_session_ticket(self, session_ticket: tls.SessionTicket) -> None:
        """
        Validate a new session ticket, then pass it to the session ticket handler.

        :raises QuicConnectionError: PROTOCOL_VIOLATION when the ticket carries a
            max_early_data_size other than MAX_EARLY_DATA.
        """
        if (
            session_ticket.max_early_data_size is not None
            and session_ticket.max_early_data_size != MAX_EARLY_DATA
        ):
            raise QuicConnectionError(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=QuicFrameType.CRYPTO,
                reason_phrase="Invalid max_early_data value "
                f"{session_ticket.max_early_data_size}",
            )
        self._session_ticket_handler(session_ticket)

    def _initialize(self, peer_cid: bytes) -> None:
        """
        Set up TLS, crypto pairs, crypto buffers/streams and packet spaces.

        :param peer_cid: Connection ID from which the Initial keys of every
            supported version are derived.
        :raises RuntimeError: when the configuration still holds certificate or
            private key objects from the ``cryptography`` package.
        """
        # TLS
        self.tls = tls.Context(
            alpn_protocols=self._configuration.alpn_protocols,
            cadata=self._configuration.cadata,
            cafile=self._configuration.cafile,
            capath=self._configuration.capath,
            cipher_suites=self.configuration.cipher_suites,
            is_client=self._is_client,
            logger=self._logger,
            max_early_data=None if self._is_client else MAX_EARLY_DATA,
            server_name=self._configuration.server_name,
            verify_mode=self._configuration.verify_mode,
            hostname_checks_common_name=self._configuration.hostname_checks_common_name,
            assert_fingerprint=self._configuration.assert_fingerprint,
            verify_hostname=self._configuration.verify_hostname,
        )

        if self._configuration.certificate is not None and not isinstance(
            self._configuration.certificate, X509Certificate
        ):
            raise RuntimeError(  # Defensive: migration from cryptography
                "qh3 v1.0+ no longer support passing cryptography "
                "certificate objects within a QuicConfiguration object. "
                "Use configuration.load_cert_chain(...) instead using "
                "PEM encoded values."
            )

        self.tls.certificate = self._configuration.certificate

        # Only the first chain element is type-checked.
        if self._configuration.certificate_chain and not isinstance(
            self._configuration.certificate_chain[0], X509Certificate
        ):
            raise RuntimeError(  # Defensive: migration from cryptography
                "qh3 v1.0+ no longer support passing cryptography "
                "certificate objects within a QuicConfiguration object. "
                "Use configuration.load_cert_chain(...) instead using "
                "PEM encoded values."
            )

        self.tls.certificate_chain = self._configuration.certificate_chain

        # Detected by type name so that ``cryptography`` need not be importable.
        if self._configuration.private_key and "cryptography" in str(
            type(self._configuration.private_key)
        ):
            raise RuntimeError(  # Defensive: migration from cryptography
                "qh3 v1.0+ no longer support passing cryptography "
                "private key object within a QuicConfiguration object. "
                "Use configuration.load_cert_chain(...) instead using "
                "PEM encoded values."
            )

        self.tls.certificate_private_key = self._configuration.private_key
        self.tls.handshake_extensions = [
            (
                tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
                self._serialize_transport_parameters(),
            )
        ]

        # TLS session resumption
        session_ticket = self._configuration.session_ticket
        if (
            self._is_client
            and session_ticket is not None
            and session_ticket.is_valid
            and session_ticket.server_name == self._configuration.server_name
        ):
            self.tls.session_ticket = self._configuration.session_ticket

            # parse saved QUIC transport parameters - for 0-RTT
            if session_ticket.max_early_data_size == MAX_EARLY_DATA:
                for ext_type, ext_data in session_ticket.other_extensions:
                    if ext_type == tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS:
                        self._parse_transport_parameters(
                            ext_data, from_session_ticket=True
                        )
                        break

        # TLS callbacks
        self.tls.alpn_cb = self._alpn_handler
        if self._session_ticket_fetcher is not None:
            self.tls.get_session_ticket_cb = self._session_ticket_fetcher
        if self._session_ticket_handler is not None:
            self.tls.new_session_ticket_cb = self._handle_session_ticket
        self.tls.update_traffic_key_cb = self._update_traffic_key

        # packet spaces
        def create_crypto_pair(epoch: tls.Epoch) -> CryptoPair:
            # Build a CryptoPair whose setup/teardown callbacks log key events
            # under qlog secret names ("server_*"/"client_*").
            epoch_name = ["initial", "0rtt", "handshake", "1rtt"][epoch.value]
            secret_names = [
                f"server_{epoch_name}_secret",
                f"client_{epoch_name}_secret",
            ]
            # Index with a bool: a client receives with the server secret and
            # sends with the client secret; a server does the opposite.
            recv_secret_name = secret_names[not self._is_client]
            send_secret_name = secret_names[self._is_client]
            return CryptoPair(
                recv_setup_cb=partial(self._log_key_updated, recv_secret_name),
                recv_teardown_cb=partial(self._log_key_retired, recv_secret_name),
                send_setup_cb=partial(self._log_key_updated, send_secret_name),
                send_teardown_cb=partial(self._log_key_retired, send_secret_name),
            )

        # To enable version negotiation, setup encryption keys for all
        # our supported versions.
        self._cryptos_initial = {}
        for version in self._configuration.supported_versions:
            pair = CryptoPair()
            pair.setup_initial(cid=peer_cid, is_client=self._is_client, version=version)
            self._cryptos_initial[version] = pair

        self._cryptos = {
            epoch: create_crypto_pair(epoch)
            for epoch in (
                tls.Epoch.INITIAL,
                tls.Epoch.ZERO_RTT,
                tls.Epoch.HANDSHAKE,
                tls.Epoch.ONE_RTT,
            )
        }
        # The logging-enabled INITIAL pair built above is replaced by the
        # pre-derived pair for the current version.
        self._cryptos[tls.Epoch.INITIAL] = self._cryptos_initial[self._version]

        self._crypto_buffers = {
            tls.Epoch.INITIAL: Buffer(capacity=CRYPTO_BUFFER_SIZE),
            tls.Epoch.HANDSHAKE: Buffer(capacity=CRYPTO_BUFFER_SIZE),
            tls.Epoch.ONE_RTT: Buffer(capacity=CRYPTO_BUFFER_SIZE),
        }
        self._crypto_streams = {
            tls.Epoch.INITIAL: QuicStream(),
            tls.Epoch.HANDSHAKE: QuicStream(),
            tls.Epoch.ONE_RTT: QuicStream(),
        }
        self._spaces = {
            tls.Epoch.INITIAL: QuicPacketSpace(),
            tls.Epoch.HANDSHAKE: QuicPacketSpace(),
            tls.Epoch.ONE_RTT: QuicPacketSpace(),
        }

        self._loss.spaces = list(self._spaces.values())

    def _handle_ack_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle an ACK frame.
        """
        ack_rangeset, ack_delay_encoded = pull_ack_frame(buf)
        if frame_type == QuicFrameType.ACK_ECN:
            # ACK_ECN carries three extra varints (ECT0, ECT1, ECN-CE counts,
            # RFC 9000 section 19.3.2); they are read and ignored.
            buf.pull_uint_var()
            buf.pull_uint_var()
            buf.pull_uint_var()
        # Encoded delay is in microseconds scaled by 2**ack_delay_exponent;
        # dividing by 1e6 yields seconds.
        ack_delay = (ack_delay_encoded << self._remote_ack_delay_exponent) / 1000000

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_ack_frame(ack_rangeset, ack_delay)
            )

        # check whether peer completed address validation
        if not self._loss.peer_completed_address_validation and context.epoch in (
            tls.Epoch.HANDSHAKE,
            tls.Epoch.ONE_RTT,
        ):
            self._loss.peer_completed_address_validation = True

        self._loss.on_ack_received(
            space=self._spaces[context.epoch],
            ack_rangeset=ack_rangeset,
            ack_delay=ack_delay,
            now=context.time,
        )

    def _handle_connection_close_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a CONNECTION_CLOSE frame.

        Only the first close frame received is recorded as the close event; the
        draining phase is (re)started every time.
        """
        error_code = buf.pull_uint_var()
        # Only the transport variant carries the offending frame type.
        if frame_type == QuicFrameType.TRANSPORT_CLOSE:
            frame_type = buf.pull_uint_var()
        else:
            frame_type = None
        reason_length = buf.pull_uint_var()
        try:
            reason_phrase = buf.pull_bytes(reason_length).decode("utf8")
        except UnicodeDecodeError:
            # An undecodable reason is replaced rather than failing the close.
            reason_phrase = ""

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_connection_close_frame(
                    error_code=error_code,
                    frame_type=frame_type,
                    reason_phrase=reason_phrase,
                )
            )

        self._logger.debug(
            "Connection close received (code 0x%X, reason %s)",
            error_code,
            reason_phrase,
        )
        if self._close_event is None:
            self._close_event = events.ConnectionTerminated(
                error_code=error_code,
                frame_type=frame_type,
                reason_phrase=reason_phrase,
            )
        self._close_begin(is_initiator=False, now=context.time)

    def _handle_crypto_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a CRYPTO frame.

        :raises QuicConnectionError: FRAME_ENCODING_ERROR when offset + length
            exceeds 2^62 - 1, CRYPTO_BUFFER_EXCEEDED when the frame ends more than
            MAX_PENDING_CRYPTO bytes past the receiver's starting offset, or
            CRYPTO_ERROR + alert code when TLS raises an alert.
        """
        offset = buf.pull_uint_var()
        length = buf.pull_uint_var()
        if offset + length > UINT_VAR_MAX:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=frame_type,
                reason_phrase="offset + length cannot exceed 2^62 - 1",
            )
        frame = QuicStreamFrame(offset=offset, data=buf.pull_bytes(length))

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_crypto_frame(frame)
            )

        stream = self._crypto_streams[context.epoch]
        # Bound how far ahead of stream.receiver.starting_offset() the peer can
        # make us buffer crypto data.
        pending = offset + length - stream.receiver.starting_offset()
        if pending > MAX_PENDING_CRYPTO:
            raise QuicConnectionError(
                error_code=QuicErrorCode.CRYPTO_BUFFER_EXCEEDED,
                frame_type=frame_type,
                reason_phrase="too much crypto buffering",
            )
        event = stream.receiver.handle_frame(frame)
        if event is not None:
            # Pass data to TLS layer, which may cause calls to:
            # - _alpn_handler
            # - _update_traffic_key
            self._crypto_frame_type = frame_type
            self._crypto_packet_version = context.version
            try:
                self.tls.handle_message(event.data, self._crypto_buffers)
                self._push_crypto_data()
            except tls.Alert as exc:
                raise QuicConnectionError(
                    error_code=QuicErrorCode.CRYPTO_ERROR + int(exc.description),
                    frame_type=frame_type,
                    reason_phrase=str(exc),
                )

            # update current epoch
            if not self._handshake_complete and self.tls.state in [
                tls.State.CLIENT_POST_HANDSHAKE,
                tls.State.SERVER_POST_HANDSHAKE,
            ]:
                self._handshake_complete = True

                # for servers, the handshake is now confirmed
                if not self._is_client:
                    self._discard_epoch(tls.Epoch.HANDSHAKE)
                    self._handshake_confirmed = True
                    self._handshake_done_pending = True

                self._replenish_connection_ids()
                self._events.append(
                    events.HandshakeCompleted(
                        alpn_protocol=self.tls.alpn_negotiated,
                        early_data_accepted=self.tls.early_data_accepted,
                        session_resumed=self.tls.session_resumed,
                    )
                )
                self._unblock_streams(is_unidirectional=False)
                self._unblock_streams(is_unidirectional=True)
                self._logger.debug(
                    "ALPN negotiated protocol %s", self.tls.alpn_negotiated
                )
        else:
            self._logger.debug(
                "Duplicate CRYPTO data received for epoch %s", context.epoch
            )

            # if a server receives duplicate CRYPTO in an INITIAL packet,
            # it can assume the client did not receive the server's CRYPTO
            if (
                not self._is_client
                and context.epoch == tls.Epoch.INITIAL
                and not self._crypto_retransmitted
            ):
                self._loss.reschedule_data(now=context.time)
                self._crypto_retransmitted = True

    def _handle_data_blocked_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a DATA_BLOCKED frame.

        The limit is parsed and logged only; no state changes.
        """
        limit = buf.pull_uint_var()

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_data_blocked_frame(limit=limit)
            )

    def _handle_datagram_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a DATAGRAM frame.

        :raises QuicConnectionError: PROTOCOL_VIOLATION when DATAGRAM support is
            not configured or the frame is too large.
        """
        start = buf.tell()
        if frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH:
            length = buf.pull_uint_var()
        else:
            # Without a length field the payload runs to the end of the packet.
            length = buf.capacity - start
        data = buf.pull_bytes(length)

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_datagram_frame(length=length)
            )

        # check frame is allowed
        # NOTE(review): ``start`` is taken after the frame type was consumed by
        # the caller (presumably 1 byte), which would explain ``>=`` here.
        if (
            self._configuration.max_datagram_frame_size is None
            or buf.tell() - start >= self._configuration.max_datagram_frame_size
        ):
            raise QuicConnectionError(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=frame_type,
                reason_phrase="Unexpected DATAGRAM frame",
            )

        self._events.append(events.DatagramFrameReceived(data=data))

    def _handle_handshake_done_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a HANDSHAKE_DONE frame.

        :raises QuicConnectionError: PROTOCOL_VIOLATION when received by a server.
        """
        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_handshake_done_frame()
            )

        if not self._is_client:
            raise QuicConnectionError(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=frame_type,
                reason_phrase="Clients must not send HANDSHAKE_DONE frames",
            )

        # for clients, the handshake is now confirmed
        if not self._handshake_confirmed:
            self._discard_epoch(tls.Epoch.HANDSHAKE)
            self._handshake_confirmed = True
            self._loss.peer_completed_address_validation = True

    def _handle_max_data_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a MAX_DATA frame.

        This adjusts the total amount of data we can send to the peer. The limit
        only ever increases; smaller values are ignored.
        """
        max_data = buf.pull_uint_var()

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_connection_limit_frame(
                    frame_type=frame_type, maximum=max_data
                )
            )

        if max_data > self._remote_max_data:
            self._logger.debug("Remote max_data raised to %d", max_data)
            self._remote_max_data = max_data

    def _handle_max_stream_data_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a MAX_STREAM_DATA frame.

        This adjusts the amount of data we can send on a specific stream. The
        limit only ever increases; smaller values are ignored.
        """
        stream_id = buf.pull_uint_var()
        max_stream_data = buf.pull_uint_var()

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_max_stream_data_frame(
                    maximum=max_stream_data, stream_id=stream_id
                )
            )

        # check stream direction
        self._assert_stream_can_send(frame_type, stream_id)

        stream = self._get_or_create_stream(frame_type, stream_id)
        if max_stream_data > stream.max_stream_data_remote:
            self._logger.debug(
                "Stream %d remote max_stream_data raised to %d",
                stream_id,
                max_stream_data,
            )
            stream.max_stream_data_remote = max_stream_data

    def _handle_max_streams_bidi_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a MAX_STREAMS_BIDI frame.

        This raises the number of bidirectional streams we can initiate to the
        peer, and unblocks streams waiting on that limit.

        :raises QuicConnectionError: FRAME_ENCODING_ERROR above 2^60 streams.
        """
        max_streams = buf.pull_uint_var()
        if max_streams > STREAM_COUNT_MAX:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=frame_type,
                reason_phrase="Maximum Streams cannot exceed 2^60",
            )

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_connection_limit_frame(
                    frame_type=frame_type, maximum=max_streams
                )
            )

        if max_streams > self._remote_max_streams_bidi:
            self._logger.debug("Remote max_streams_bidi raised to %d", max_streams)
            self._remote_max_streams_bidi = max_streams
            self._unblock_streams(is_unidirectional=False)

    def _handle_max_streams_uni_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a MAX_STREAMS_UNI frame.

        This raises the number of unidirectional streams we can initiate to the
        peer, and unblocks streams waiting on that limit.

        :raises QuicConnectionError: FRAME_ENCODING_ERROR above 2^60 streams.
        """
        max_streams = buf.pull_uint_var()
        if max_streams > STREAM_COUNT_MAX:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=frame_type,
                reason_phrase="Maximum Streams cannot exceed 2^60",
            )

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_connection_limit_frame(
                    frame_type=frame_type, maximum=max_streams
                )
            )

        if max_streams > self._remote_max_streams_uni:
            self._logger.debug("Remote max_streams_uni raised to %d", max_streams)
            self._remote_max_streams_uni = max_streams
            self._unblock_streams(is_unidirectional=True)

    def _handle_new_connection_id_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a NEW_CONNECTION_ID frame.

        Records the new peer CID, retires every CID below the highest Retire
        Prior To seen so far (switching away from the active one if needed), and
        enforces limits on active and pending-retirement CIDs.

        :raises QuicConnectionError: FRAME_ENCODING_ERROR for an empty or
            oversized CID, PROTOCOL_VIOLATION when Retire Prior To exceeds the
            sequence number, CONNECTION_ID_LIMIT_ERROR when a limit is exceeded.
        """
        sequence_number = buf.pull_uint_var()
        retire_prior_to = buf.pull_uint_var()
        length = buf.pull_uint8()
        connection_id = buf.pull_bytes(length)
        stateless_reset_token = buf.pull_bytes(STATELESS_RESET_TOKEN_SIZE)
        # NOTE(review): a CID of exactly CONNECTION_ID_MAX_SIZE bytes is
        # accepted, although the message below says "less than 20".
        if not connection_id or len(connection_id) > CONNECTION_ID_MAX_SIZE:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=frame_type,
                reason_phrase="Length must be greater than 0 and less than 20",
            )

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_new_connection_id_frame(
                    connection_id=connection_id,
                    retire_prior_to=retire_prior_to,
                    sequence_number=sequence_number,
                    stateless_reset_token=stateless_reset_token,
                )
            )

        # sanity check
        if retire_prior_to > sequence_number:
            raise QuicConnectionError(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=frame_type,
                reason_phrase="Retire Prior To is greater than Sequence Number",
            )

        # only accept retire_prior_to if it is bigger than the one we know
        self._peer_retire_prior_to = max(retire_prior_to, self._peer_retire_prior_to)

        # determine which CIDs to retire
        change_cid = False
        retire = [
            cid
            for cid in self._peer_cid_available
            if cid.sequence_number < self._peer_retire_prior_to
        ]
        if self._peer_cid.sequence_number < retire_prior_to:
            change_cid = True
            retire.insert(0, self._peer_cid)

        # update available CIDs
        self._peer_cid_available = [
            cid
            for cid in self._peer_cid_available
            if cid.sequence_number >= self._peer_retire_prior_to
        ]
        # Duplicate sequence numbers (e.g. retransmitted frames) are not re-added.
        if (
            sequence_number >= self._peer_retire_prior_to
            and sequence_number not in self._peer_cid_sequence_numbers
        ):
            self._peer_cid_available.append(
                QuicConnectionId(
                    cid=connection_id,
                    sequence_number=sequence_number,
                    stateless_reset_token=stateless_reset_token,
                )
            )
            self._peer_cid_sequence_numbers.add(sequence_number)

        # retire previous CIDs
        for quic_connection_id in retire:
            self._retire_peer_cid(quic_connection_id)

        # assign new CID if we retired the active one
        if change_cid:
            self._consume_peer_cid()

        # check number of active connection IDs, including the selected one
        if 1 + len(self._peer_cid_available) > self._local_active_connection_id_limit:
            raise QuicConnectionError(
                error_code=QuicErrorCode.CONNECTION_ID_LIMIT_ERROR,
                frame_type=frame_type,
                reason_phrase="Too many active connection IDs",
            )

        # Check the number of retired connection IDs pending, though with a safer limit
        # than the 2x recommended in section 5.1.2 of the RFC.  Note that we are doing
        # the check here and not in _retire_peer_cid() because we know the frame type to
        # use here, and because it is the new connection id path that is potentially
        # dangerous.  We may transiently go a bit over the limit due to unacked frames
        # getting added back to the list, but that's ok as it is bounded.
        if len(self._retire_connection_ids) > min(
            self._local_active_connection_id_limit * 4, MAX_PENDING_RETIRES
        ):
            raise QuicConnectionError(
                error_code=QuicErrorCode.CONNECTION_ID_LIMIT_ERROR,
                frame_type=frame_type,
                reason_phrase="Too many pending retired connection IDs",
            )

    def _handle_new_token_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a NEW_TOKEN frame.

        The token is parsed and logged only.

        :raises QuicConnectionError: PROTOCOL_VIOLATION when received by a server.
        """
        length = buf.pull_uint_var()
        token = buf.pull_bytes(length)

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_new_token_frame(token=token)
            )

        if not self._is_client:
            raise QuicConnectionError(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=frame_type,
                reason_phrase="Clients must not send NEW_TOKEN frames",
            )

    def _handle_padding_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a PADDING frame.

        Consecutive zero bytes are skipped so that a whole run of padding is
        consumed and logged as a single frame.
        """
        # consume padding
        pos = buf.tell()
        for byte in buf.data_slice(pos, buf.capacity):
            if byte:
                break
            pos += 1
        buf.seek(pos)

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(self._quic_logger.encode_padding_frame())

    def _handle_path_challenge_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a PATH_CHALLENGE frame.

        The 8-byte challenge is queued on the receiving network path.
        """
        data = buf.pull_bytes(8)

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_path_challenge_frame(data=data)
            )

        # Append the new path challenge unless our limit was reached.
        # This is technically not compliant with RFC 9000 but it's needed
        # to avoid resource exhaustion attacks.
        if len(context.network_path.remote_challenges) < MAX_REMOTE_CHALLENGES:
            context.network_path.remote_challenges.append(data)

    def _handle_path_response_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a PATH_RESPONSE frame.

        A response matching one of our outstanding challenges validates the path
        that challenge was sent on.

        :raises QuicConnectionError: PROTOCOL_VIOLATION when no outstanding
            challenge matches.
        """
        data = buf.pull_bytes(8)

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_path_response_frame(data=data)
            )

        try:
            network_path = self._local_challenges.pop(data)
        except KeyError:
            raise QuicConnectionError(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=frame_type,
                reason_phrase="Response does not match challenge",
            )
        self._logger.debug("Network path %s validated by challenge", network_path.addr)
        network_path.is_validated = True

    def _handle_ping_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a PING frame.

        Nothing to do beyond logging.
        """
        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(self._quic_logger.encode_ping_frame())

    def _handle_reset_stream_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a RESET_STREAM frame.

        The final size is checked against stream and connection flow-control
        limits; any bytes it implies beyond what was already received count
        against the connection limit.

        :raises QuicConnectionError: STREAM_STATE_ERROR, FLOW_CONTROL_ERROR or
            FINAL_SIZE_ERROR.
        """
        stream_id = buf.pull_uint_var()
        error_code = buf.pull_uint_var()
        final_size = buf.pull_uint_var()

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_reset_stream_frame(
                    error_code=error_code, final_size=final_size, stream_id=stream_id
                )
            )

        # check stream direction
        self._assert_stream_can_receive(frame_type, stream_id)

        # check flow-control limits
        stream = self._get_or_create_stream(frame_type, stream_id)
        if final_size > stream.max_stream_data_local:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FLOW_CONTROL_ERROR,
                frame_type=frame_type,
                reason_phrase="Over stream data limit",
            )
        newly_received = max(0, final_size - stream.receiver.highest_offset)
        if self._local_max_data.used + newly_received > self._local_max_data.value:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FLOW_CONTROL_ERROR,
                frame_type=frame_type,
                reason_phrase="Over connection data limit",
            )

        # process reset
        self._logger.debug(
            "Stream %d reset by peer (error code %d, final size %d)",
            stream_id,
            error_code,
            final_size,
        )
        try:
            event = stream.receiver.handle_reset(
                error_code=error_code, final_size=final_size
            )
        except FinalSizeError as exc:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FINAL_SIZE_ERROR,
                frame_type=frame_type,
                reason_phrase=str(exc),
            )
        if event is not None:
            self._events.append(event)
        self._local_max_data.used += newly_received
        if newly_received > 0:
            self._streams_dirty_limits.add(stream)

    def _handle_retire_connection_id_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a RETIRE_CONNECTION_ID frame.

        Removes one of our host CIDs and issues replacements.

        :raises QuicConnectionError: PROTOCOL_VIOLATION for a sequence number we
            never issued, or for the CID the frame itself arrived on.
        """
        sequence_number = buf.pull_uint_var()

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_retire_connection_id_frame(sequence_number)
            )

        if sequence_number >= self._host_cid_seq:
            raise QuicConnectionError(
                error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                frame_type=frame_type,
                reason_phrase="Cannot retire unknown connection ID",
            )

        # find the connection ID by sequence number
        for index, connection_id in enumerate(self._host_cids):
            if connection_id.sequence_number == sequence_number:
                if connection_id.cid == context.host_cid:
                    raise QuicConnectionError(
                        error_code=QuicErrorCode.PROTOCOL_VIOLATION,
                        frame_type=frame_type,
                        reason_phrase="Cannot retire current connection ID",
                    )
                self._logger.debug(
                    "Peer retiring CID %s (%d)",
                    dump_cid(connection_id.cid),
                    connection_id.sequence_number,
                )
                del self._host_cids[index]
                self._events.append(
                    events.ConnectionIdRetired(connection_id=connection_id.cid)
                )
                break

        # issue a new connection ID
        self._replenish_connection_ids()

    def _handle_stop_sending_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a STOP_SENDING frame.

        Resets our sending side of the stream and notifies the application with
        a StopSendingReceived event carrying the peer's error code.
        """
        stream_id = buf.pull_uint_var()
        error_code = buf.pull_uint_var()  # application error code

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_stop_sending_frame(
                    error_code=error_code, stream_id=stream_id
                )
            )

        # check stream direction
        self._assert_stream_can_send(frame_type, stream_id)

        # reset the stream
        # NOTE(review): RFC 9000 section 3.5 says the STOP_SENDING error code
        # SHOULD be copied into RESET_STREAM; NO_ERROR is used here — confirm intent.
        stream = self._get_or_create_stream(frame_type, stream_id)
        stream.sender.reset(error_code=QuicErrorCode.NO_ERROR)

        self._events.append(
            events.StopSendingReceived(error_code=error_code, stream_id=stream_id)
        )

    def _handle_stream_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a STREAM frame.

        The low bits of ``frame_type`` select the optional fields:
        0x04 OFF (offset present), 0x02 LEN (length present), 0x01 FIN.

        :raises QuicConnectionError: FRAME_ENCODING_ERROR, STREAM_STATE_ERROR,
            FLOW_CONTROL_ERROR or FINAL_SIZE_ERROR.
        """
        stream_id = buf.pull_uint_var()
        if frame_type & 4:
            offset = buf.pull_uint_var()
        else:
            offset = 0
        if frame_type & 2:
            length = buf.pull_uint_var()
        else:
            # Without a length field the data runs to the end of the packet.
            length = buf.capacity - buf.tell()
        if offset + length > UINT_VAR_MAX:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=frame_type,
                reason_phrase="offset + length cannot exceed 2^62 - 1",
            )
        frame = QuicStreamFrame(
            offset=offset, data=buf.pull_bytes(length), fin=bool(frame_type & 1)
        )

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_stream_frame(frame, stream_id=stream_id)
            )

        # check stream direction
        self._assert_stream_can_receive(frame_type, stream_id)

        # check flow-control limits
        stream = self._get_or_create_stream(frame_type, stream_id)
        if offset + length > stream.max_stream_data_local:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FLOW_CONTROL_ERROR,
                frame_type=frame_type,
                reason_phrase="Over stream data limit",
            )
        # Only bytes beyond the highest offset seen so far consume connection credit.
        newly_received = max(0, offset + length - stream.receiver.highest_offset)
        if self._local_max_data.used + newly_received > self._local_max_data.value:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FLOW_CONTROL_ERROR,
                frame_type=frame_type,
                reason_phrase="Over connection data limit",
            )

        # process data
        try:
            event = stream.receiver.handle_frame(frame)
        except FinalSizeError as exc:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FINAL_SIZE_ERROR,
                frame_type=frame_type,
                reason_phrase=str(exc),
            )
        if event is not None:
            self._events.append(event)
        self._local_max_data.used += newly_received
        if newly_received > 0:
            self._streams_dirty_limits.add(stream)

    def _handle_stream_data_blocked_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a STREAM_DATA_BLOCKED frame.

        The limit is only logged, but the stream is still validated and may be
        implicitly opened.
        """
        stream_id = buf.pull_uint_var()
        limit = buf.pull_uint_var()

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_stream_data_blocked_frame(
                    limit=limit, stream_id=stream_id
                )
            )

        # check stream direction
        self._assert_stream_can_receive(frame_type, stream_id)

        self._get_or_create_stream(frame_type, stream_id)

    def _handle_streams_blocked_frame(
        self, context: QuicReceiveContext, frame_type: int, buf: Buffer
    ) -> None:
        """
        Handle a STREAMS_BLOCKED frame.

        The limit is validated and logged only.

        :raises QuicConnectionError: FRAME_ENCODING_ERROR above 2^60 streams.
        """
        limit = buf.pull_uint_var()
        if limit > STREAM_COUNT_MAX:
            raise QuicConnectionError(
                error_code=QuicErrorCode.FRAME_ENCODING_ERROR,
                frame_type=frame_type,
                reason_phrase="Maximum Streams cannot exceed 2^60",
            )

        # log frame
        if self._quic_logger is not None:
            context.quic_logger_frames.append(
                self._quic_logger.encode_streams_blocked_frame(
                    is_unidirectional=frame_type == QuicFrameType.STREAMS_BLOCKED_UNI,
                    limit=limit,
                )
            )

    def _log_key_retired(self, key_type: str, trigger: str) -> None:
        """
        Log a key retirement.

        :param key_type: qlog secret name, e.g. "client_1rtt_secret".
        :param trigger: Reason supplied by the crypto layer.
        """
        if self._quic_logger is not None:
            self._quic_logger.log_event(
                category="security",
                event="key_retired",
                data={"key_type": key_type, "trigger": trigger},
            )

    def _log_key_updated(self, key_type: str, trigger: str) -> None:
        """
        Log a key update.

        :param key_type: qlog secret name, e.g. "server_handshake_secret".
        :param trigger: Reason supplied by the crypto layer.
        """
        if self._quic_logger is not None:
            self._quic_logger.log_event(
                category="security",
                event="key_updated",
                data={"key_type": key_type, "trigger": trigger},
            )

    def _on_ack_delivery(
        self, delivery: QuicDeliveryState, space: QuicPacketSpace, highest_acked: int
    ) -> None:
        """
        Callback when an ACK frame is acknowledged or lost.

        Once one of our ACK frames is acknowledged, packet numbers up to and
        including ``highest_acked`` are removed from the ack queue. A lost ACK
        frame changes nothing.
        """
        if delivery == QuicDeliveryState.ACKED:
            space.ack_queue.subtract(0, highest_acked + 1)

    def _on_connection_limit_delivery(
        self, delivery: QuicDeliveryState, limit: Limit
    ) -> None:
        """
        Callback when a MAX_DATA or MAX_STREAMS frame is acknowledged or lost.
""" if delivery != QuicDeliveryState.ACKED: limit.sent = 0 def _on_handshake_done_delivery(self, delivery: QuicDeliveryState) -> None: """ Callback when a HANDSHAKE_DONE frame is acknowledged or lost. """ if delivery != QuicDeliveryState.ACKED: self._handshake_done_pending = True def _on_max_stream_data_delivery( self, delivery: QuicDeliveryState, stream: QuicStream ) -> None: """ Callback when a MAX_STREAM_DATA frame is acknowledged or lost. """ if delivery != QuicDeliveryState.ACKED: stream.max_stream_data_local_sent = 0 self._streams_dirty_limits.add(stream) def _on_new_connection_id_delivery( self, delivery: QuicDeliveryState, connection_id: QuicConnectionId ) -> None: """ Callback when a NEW_CONNECTION_ID frame is acknowledged or lost. """ if delivery != QuicDeliveryState.ACKED: connection_id.was_sent = False def _on_ping_delivery( self, delivery: QuicDeliveryState, uids: Sequence[int] ) -> None: """ Callback when a PING frame is acknowledged or lost. """ if delivery == QuicDeliveryState.ACKED: self._logger.debug("Received PING%s response", "" if uids else " (probe)") for uid in uids: self._events.append(events.PingAcknowledged(uid=uid)) else: self._ping_pending.extend(uids) def _on_mtu_probe_delivery( self, delivery: QuicDeliveryState, probe_size: int ) -> None: """ Callback when an MTU probe PING frame is acknowledged or lost. """ if delivery == QuicDeliveryState.ACKED: self._logger.debug("MTU probe ACK'd, datagram size now %d", probe_size) self._max_datagram_size = probe_size self._loss._cc._max_datagram_size = probe_size if self._mtu_probe_sizes and self._mtu_probe_sizes[0] == probe_size: self._mtu_probe_sizes.pop(0) self._mtu_probe_pending = None else: self._logger.debug("MTU probe for %d lost, stopping", probe_size) self._mtu_probe_sizes.clear() self._mtu_probe_pending = None def _on_retire_connection_id_delivery( self, delivery: QuicDeliveryState, sequence_number: int ) -> None: """ Callback when a RETIRE_CONNECTION_ID frame is acknowledged or lost. 
""" if delivery != QuicDeliveryState.ACKED: self._retire_connection_ids.append(sequence_number) def _payload_received( self, context: QuicReceiveContext, plain: bytes, crypto_frame_required: bool = False, ) -> tuple[bool, bool]: """ Handle a QUIC packet payload. """ buf = Buffer(data=plain) crypto_frame_found = False frame_found = False is_ack_eliciting = False is_probing = None while not buf.eof(): # get frame type try: frame_type = buf.pull_uint_var() except BufferReadError: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=None, reason_phrase="Malformed frame type", ) # check frame type is known try: frame_handler, frame_epochs = self.__frame_handlers[frame_type] except KeyError: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="Unknown frame type", ) # check frame type is allowed for the epoch if context.epoch not in frame_epochs: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=frame_type, reason_phrase="Unexpected frame type", ) # handle the frame try: frame_handler(context, frame_type, buf) except BufferReadError: raise QuicConnectionError( error_code=QuicErrorCode.FRAME_ENCODING_ERROR, frame_type=frame_type, reason_phrase="Failed to parse frame", ) except StreamFinishedError: # we lack the state for the stream, ignore the frame pass # update ACK only / probing flags frame_found = True if frame_type == QuicFrameType.CRYPTO: crypto_frame_found = True if frame_type not in NON_ACK_ELICITING_FRAME_TYPES: is_ack_eliciting = True if frame_type not in PROBING_FRAME_TYPES: is_probing = False elif is_probing is None: is_probing = True if not frame_found: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=QuicFrameType.PADDING, reason_phrase="Packet contains no frames", ) # RFC 9000 - 17.2.2. Initial Packet # The first packet sent by a client always includes a CRYPTO frame. 
if crypto_frame_required and not crypto_frame_found: raise QuicConnectionError( error_code=QuicErrorCode.PROTOCOL_VIOLATION, frame_type=QuicFrameType.PADDING, reason_phrase="Packet contains no CRYPTO frame", ) return is_ack_eliciting, bool(is_probing) def _receive_retry_packet( self, header: QuicHeader, packet_without_tag: bytes, now: float ) -> None: """ Handle a retry packet. """ if ( self._is_client and not self._retry_count and header.destination_cid == self.host_cid and header.integrity_tag == get_retry_integrity_tag( packet_without_tag, self._peer_cid.cid, version=header.version, ) ): if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_received", data={ "frames": [], "header": { "packet_type": "retry", "scid": dump_cid(header.source_cid), "dcid": dump_cid(header.destination_cid), }, "raw": {"length": header.packet_length}, }, ) self._peer_cid.cid = header.source_cid self._peer_token = header.token self._retry_count += 1 self._retry_source_connection_id = header.source_cid self._logger.debug(f"Retrying with token ({len(header.token)} bytes)") self._connect(now=now) else: # Unexpected or invalid retry packet. if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "unexpected_packet", "raw": {"length": header.packet_length}, }, ) def _receive_version_negotiation_packet( self, header: QuicHeader, now: float ) -> None: """ Handle a version negotiation packet. This is used in "Incompatible Version Negotiation", see: https://datatracker.ietf.org/doc/html/rfc9368#section-2.2 """ # Only clients process Version Negotiation, and once a Version # Negotiation packet has been acted upon, any further # such packets must be ignored. 
# # https://datatracker.ietf.org/doc/html/rfc9368#section-4 if ( self._is_client and self._state is QuicConnectionState.FIRSTFLIGHT and not self._version_negotiated_incompatible ): if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_received", data={ "frames": [], "header": { "packet_type": self._quic_logger.packet_type( header.packet_type ), "scid": dump_cid(header.source_cid), "dcid": dump_cid(header.destination_cid), }, "raw": {"length": header.packet_length}, }, ) # Ignore any Version Negotiation packets that contain the # original version. # # https://datatracker.ietf.org/doc/html/rfc9368#section-4 if self._version in header.supported_versions: self._logger.debug( "Version negotiation packet contains protocol version %s", pretty_protocol_version(self._version), ) return # Look for a common protocol version. common = [ x for x in self._configuration.supported_versions if x in header.supported_versions ] # Look for a common protocol version. chosen_version = common[0] if common else None if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="version_information", data={ "server_versions": header.supported_versions, "client_versions": self._configuration.supported_versions, "chosen_version": chosen_version, }, ) if chosen_version is None: self._logger.debug("Could not find a common protocol version") self._close_event = events.ConnectionTerminated( error_code=QuicErrorCode.INTERNAL_ERROR, frame_type=QuicFrameType.PADDING, reason_phrase="Could not find a common protocol version", ) self._close_end() return self._packet_number = 0 self._version = chosen_version self._version_negotiated_incompatible = True self._logger.debug( "Retrying with protocol version %s", pretty_protocol_version(self._version), ) self._connect(now=now) else: # Unexpected version negotiation packet. 
if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="packet_dropped", data={ "trigger": "unexpected_packet", "raw": {"length": header.packet_length}, }, ) def _replenish_connection_ids(self) -> None: """ Generate new connection IDs. """ while len(self._host_cids) < min(8, self._remote_active_connection_id_limit): self._host_cids.append( QuicConnectionId( cid=os.urandom(self._configuration.connection_id_length), sequence_number=self._host_cid_seq, stateless_reset_token=os.urandom(16), ) ) self._host_cid_seq += 1 def _retire_peer_cid(self, connection_id: QuicConnectionId) -> None: """ Retire a destination connection ID. """ self._logger.debug( "Retiring CID %s (%d) [%d]", dump_cid(connection_id.cid), connection_id.sequence_number, len(self._retire_connection_ids) + 1, ) self._retire_connection_ids.append(connection_id.sequence_number) def _push_crypto_data(self) -> None: for epoch, buf in self._crypto_buffers.items(): self._crypto_streams[epoch].sender.write(buf.data) buf.seek(0) def _send_probe(self) -> None: self._probe_pending = True def _parse_transport_parameters( self, data: bytes, from_session_ticket: bool = False ) -> None: """ Parse and apply remote transport parameters. `from_session_ticket` is `True` when restoring saved transport parameters, and `False` when handling received transport parameters. 
""" try: quic_transport_parameters = pull_quic_transport_parameters( Buffer(data=data) ) except ValueError: raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="Could not parse QUIC transport parameters", ) # log event if self._quic_logger is not None and not from_session_ticket: self._quic_logger.log_event( category="transport", event="parameters_set", data=self._quic_logger.encode_transport_parameters( owner="remote", parameters=quic_transport_parameters ), ) # validate remote parameters if not self._is_client: for attr in [ "original_destination_connection_id", "preferred_address", "retry_source_connection_id", "stateless_reset_token", ]: if getattr(quic_transport_parameters, attr) is not None: raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase=f"{attr} is not allowed for clients", ) if not from_session_ticket: if ( quic_transport_parameters.initial_source_connection_id != self._remote_initial_source_connection_id ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="initial_source_connection_id does not match", ) if self._is_client and ( quic_transport_parameters.original_destination_connection_id != self._original_destination_connection_id ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="original_destination_connection_id does not match", ) if self._is_client and ( quic_transport_parameters.retry_source_connection_id != self._retry_source_connection_id ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="retry_source_connection_id does not match", ) if ( quic_transport_parameters.active_connection_id_limit is not None and quic_transport_parameters.active_connection_id_limit < 2 ): 
raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="active_connection_id_limit must be no less than 2", ) if ( quic_transport_parameters.ack_delay_exponent is not None and quic_transport_parameters.ack_delay_exponent > 20 ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="ack_delay_exponent must be <= 20", ) if ( quic_transport_parameters.max_ack_delay is not None and quic_transport_parameters.max_ack_delay >= 2**14 ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="max_ack_delay must be < 2^14", ) if ( quic_transport_parameters.max_udp_payload_size is not None and quic_transport_parameters.max_udp_payload_size < SMALLEST_MAX_DATAGRAM_SIZE ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase="max_udp_payload_size must " f"be >= {SMALLEST_MAX_DATAGRAM_SIZE}", ) # Validate Version Information extension. # # https://datatracker.ietf.org/doc/html/rfc9368#section-4 if quic_transport_parameters.version_information is not None: version_information = quic_transport_parameters.version_information # If a server receives Version Information where the Chosen Version # is not included in Available Versions, it MUST treat is as a # parsing failure. if ( not self._is_client and version_information.chosen_version not in version_information.available_versions ): raise QuicConnectionError( error_code=QuicErrorCode.TRANSPORT_PARAMETER_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase=( "version_information's chosen_version is not included " "in available_versions" ), ) # Validate that the Chosen Version matches the version in use for the # connection. 
if version_information.chosen_version != self._crypto_packet_version: raise QuicConnectionError( error_code=QuicErrorCode.VERSION_NEGOTIATION_ERROR, frame_type=QuicFrameType.CRYPTO, reason_phrase=( "version_information's chosen_version does not match " "the version in use" ), ) # store remote parameters if not from_session_ticket: if quic_transport_parameters.ack_delay_exponent is not None: self._remote_ack_delay_exponent = self._remote_ack_delay_exponent if quic_transport_parameters.max_ack_delay is not None: self._loss.max_ack_delay = ( quic_transport_parameters.max_ack_delay / 1000.0 ) if ( self._is_client and self._peer_cid.sequence_number == 0 and quic_transport_parameters.stateless_reset_token is not None ): self._peer_cid.stateless_reset_token = ( quic_transport_parameters.stateless_reset_token ) self._remote_version_information = ( quic_transport_parameters.version_information ) if quic_transport_parameters.active_connection_id_limit is not None: self._remote_active_connection_id_limit = ( quic_transport_parameters.active_connection_id_limit ) if quic_transport_parameters.max_idle_timeout is not None: self._remote_max_idle_timeout = ( quic_transport_parameters.max_idle_timeout / 1000.0 ) self._remote_max_datagram_frame_size = ( quic_transport_parameters.max_datagram_frame_size ) for param in [ "max_data", "max_stream_data_bidi_local", "max_stream_data_bidi_remote", "max_stream_data_uni", "max_streams_bidi", "max_streams_uni", ]: value = getattr(quic_transport_parameters, "initial_" + param) if value is not None: setattr(self, "_remote_" + param, value) # Cap MTU probe sizes to the peer's max_udp_payload_size. 
if ( self._mtu_probe_sizes and quic_transport_parameters.max_udp_payload_size is not None ): peer_max = quic_transport_parameters.max_udp_payload_size capped: list[int] = [] for s in self._mtu_probe_sizes: size = min(s, peer_max) if size > self._max_datagram_size and ( not capped or size != capped[-1] ): capped.append(size) self._mtu_probe_sizes = capped def _serialize_transport_parameters(self) -> bytes: quic_transport_parameters = QuicTransportParameters( ack_delay_exponent=self._local_ack_delay_exponent, active_connection_id_limit=self._local_active_connection_id_limit, max_idle_timeout=int(self._configuration.idle_timeout * 1000), initial_max_data=self._local_max_data.value, initial_max_stream_data_bidi_local=self._local_max_stream_data_bidi_local, initial_max_stream_data_bidi_remote=self._local_max_stream_data_bidi_remote, initial_max_stream_data_uni=self._local_max_stream_data_uni, initial_max_streams_bidi=self._local_max_streams_bidi.value, initial_max_streams_uni=self._local_max_streams_uni.value, initial_source_connection_id=self._local_initial_source_connection_id, max_ack_delay=25, max_datagram_frame_size=self._configuration.max_datagram_frame_size, quantum_readiness=( b"Q" * SMALLEST_MAX_DATAGRAM_SIZE if self._configuration.quantum_readiness_test else None ), stateless_reset_token=self._host_cids[0].stateless_reset_token, version_information=QuicVersionInformation( chosen_version=self._version, available_versions=self._configuration.supported_versions, ), ) if not self._is_client: quic_transport_parameters.original_destination_connection_id = ( self._original_destination_connection_id ) quic_transport_parameters.retry_source_connection_id = ( self._retry_source_connection_id ) # log event if self._quic_logger is not None: self._quic_logger.log_event( category="transport", event="parameters_set", data=self._quic_logger.encode_transport_parameters( owner="local", parameters=quic_transport_parameters ), ) buf = Buffer(capacity=3 * self._max_datagram_size) 
push_quic_transport_parameters(buf, quic_transport_parameters) return buf.data def _set_state(self, state: QuicConnectionState) -> None: self._logger.debug("%s -> %s", self._state.name, state.name) self._state = state def _unblock_streams(self, is_unidirectional: bool) -> None: if is_unidirectional: max_stream_data_remote = self._remote_max_stream_data_uni max_streams = self._remote_max_streams_uni streams_blocked = self._streams_blocked_uni else: max_stream_data_remote = self._remote_max_stream_data_bidi_remote max_streams = self._remote_max_streams_bidi streams_blocked = self._streams_blocked_bidi while streams_blocked and streams_blocked[0].stream_id // 4 < max_streams: stream = streams_blocked.pop(0) stream.is_blocked = False stream.max_stream_data_remote = max_stream_data_remote if not self._streams_blocked_bidi and not self._streams_blocked_uni: self._streams_blocked_pending = False def _update_traffic_key( self, direction: tls.Direction, epoch: tls.Epoch, cipher_suite: tls.CipherSuite, secret: bytes, ) -> None: """ Callback which is invoked by the TLS engine when new traffic keys are available. """ # For clients, determine the negotiated protocol version. 
if ( self._is_client and self._crypto_packet_version is not None and not self._version_negotiated_compatible ): self._version = self._crypto_packet_version self._version_negotiated_compatible = True self._logger.debug( "Negotiated protocol version %s", pretty_protocol_version(self._version) ) secrets_log_file = self._configuration.secrets_log_file if secrets_log_file is not None: label_row = self._is_client == (direction == tls.Direction.DECRYPT) label = SECRETS_LABELS[label_row][epoch.value] secrets_log_file.write( f"{label} {self.tls.client_random.hex()} {secret.hex()}\n" ) secrets_log_file.flush() crypto = self._cryptos[epoch] if direction == tls.Direction.ENCRYPT: crypto.send.setup( cipher_suite=cipher_suite, secret=secret, version=self._version ) else: crypto.recv.setup( cipher_suite=cipher_suite, secret=secret, version=self._version ) def _add_local_challenge(self, challenge: bytes, network_path: QuicNetworkPath): self._local_challenges[challenge] = network_path while len(self._local_challenges) > MAX_LOCAL_CHALLENGES: # Dictionaries are ordered, so pop the first key until we are below the # limit. 
key = next(iter(self._local_challenges.keys())) del self._local_challenges[key] def _write_application( self, builder: QuicPacketBuilder, network_path: QuicNetworkPath, now: float ) -> None: crypto_stream: QuicStream | None = None if self._cryptos[tls.Epoch.ONE_RTT].send.is_valid(): crypto = self._cryptos[tls.Epoch.ONE_RTT] crypto_stream = self._crypto_streams[tls.Epoch.ONE_RTT] packet_type = QuicPacketType.ONE_RTT elif self._cryptos[tls.Epoch.ZERO_RTT].send.is_valid(): crypto = self._cryptos[tls.Epoch.ZERO_RTT] packet_type = QuicPacketType.ZERO_RTT else: return space = self._spaces[tls.Epoch.ONE_RTT] while True: # apply pacing, except if we have ACKs to send if space.ack_at is None or space.ack_at >= now: self._pacing_at = self._loss._pacer.next_send_time(now=now) if self._pacing_at is not None: break builder.start_packet(packet_type, crypto) if self._handshake_complete: # PATH CHALLENGE if not (network_path.is_validated or network_path.local_challenge_sent): challenge = os.urandom(8) self._write_path_challenge_frame( builder=builder, challenge=challenge ) self._add_local_challenge( challenge=challenge, network_path=network_path ) network_path.local_challenge_sent = True # ACK if space.ack_at is not None and space.ack_at <= now: self._write_ack_frame(builder=builder, space=space, now=now) # HANDSHAKE_DONE if self._handshake_done_pending: self._write_handshake_done_frame(builder=builder) self._handshake_done_pending = False # PATH RESPONSE while network_path.remote_challenges: self._write_path_response_frame( builder=builder, challenge=network_path.remote_challenges[0] ) network_path.remote_challenges.popleft() # NEW_CONNECTION_ID for connection_id in self._host_cids: if not connection_id.was_sent: self._write_new_connection_id_frame( builder=builder, connection_id=connection_id ) # RETIRE_CONNECTION_ID if self._retire_connection_ids: for sequence_number in self._retire_connection_ids: self._write_retire_connection_id_frame( builder=builder, 
sequence_number=sequence_number ) self._retire_connection_ids.clear() # STREAMS_BLOCKED if self._streams_blocked_pending: if self._streams_blocked_bidi: self._write_streams_blocked_frame( builder=builder, frame_type=QuicFrameType.STREAMS_BLOCKED_BIDI, limit=self._remote_max_streams_bidi, ) if self._streams_blocked_uni: self._write_streams_blocked_frame( builder=builder, frame_type=QuicFrameType.STREAMS_BLOCKED_UNI, limit=self._remote_max_streams_uni, ) self._streams_blocked_pending = False # MAX_DATA and MAX_STREAMS self._write_connection_limits(builder=builder, space=space) # stream-level limits if self._streams_dirty_limits: for stream in self._streams_dirty_limits: self._write_stream_limits( builder=builder, space=space, stream=stream ) self._streams_dirty_limits.clear() # PING (user-request) if self._ping_pending: self._write_ping_frame(builder, self._ping_pending) self._ping_pending.clear() # PING (probe) if self._probe_pending: self._write_ping_frame(builder, comment="probe") self._probe_pending = False # CRYPTO if crypto_stream is not None and not crypto_stream.sender.buffer_is_empty: self._write_crypto_frame( builder=builder, space=space, stream=crypto_stream ) # DATAGRAM while self._datagrams_pending: datagram_pending = self._datagrams_pending.popleft() try: self._write_datagram_frame( builder=builder, data=datagram_pending, frame_type=QuicFrameType.DATAGRAM_WITH_LENGTH, ) except QuicPacketBuilderStop: self._datagrams_pending.appendleft(datagram_pending) break queue = self._streams_queue sent: list[QuicStream] = [] write_idx = 0 try: for stream in queue: # if the stream is finished, discard it if stream.is_finished: self._logger.debug(f"Stream {stream.stream_id} discarded") del self._streams[stream.stream_id] self._streams_finished.add(stream.stream_id) self._streams_dirty_limits.discard(stream) continue if stream.receiver.stop_pending: # STOP_SENDING self._write_stop_sending_frame(builder=builder, stream=stream) if stream.sender.reset_pending: # 
RESET_STREAM self._write_reset_stream_frame(builder=builder, stream=stream) elif not stream.is_blocked and not stream.sender.buffer_is_empty: # STREAM used = self._write_stream_frame( builder=builder, space=space, stream=stream, max_offset=min( stream.sender.highest_offset + self._remote_max_data - self._remote_max_data_used, stream.max_stream_data_remote, ), ) self._remote_max_data_used += used if used > 0: sent.append(stream) continue queue[write_idx] = stream write_idx += 1 finally: # Compact in-place: reshelved streams are at queue[0:write_idx], # trim the rest and append sent streams to the end for fairness. del queue[write_idx:] queue.extend(sent) if builder.packet_is_empty: break else: self._loss._pacer.update_after_send(now=now) def _write_handshake( self, builder: QuicPacketBuilder, epoch: tls.Epoch, now: float ) -> None: crypto = self._cryptos[epoch] if not crypto.send.is_valid(): return crypto_stream = self._crypto_streams[epoch] space = self._spaces[epoch] while True: if epoch == tls.Epoch.INITIAL: packet_type = QuicPacketType.INITIAL else: packet_type = QuicPacketType.HANDSHAKE builder.start_packet(packet_type, crypto) # ACK if space.ack_at is not None: self._write_ack_frame(builder=builder, space=space, now=now) # CRYPTO if not crypto_stream.sender.buffer_is_empty: if self._write_crypto_frame( builder=builder, space=space, stream=crypto_stream ): self._probe_pending = False # PING (probe) if ( self._probe_pending and not self._handshake_complete and ( epoch == tls.Epoch.HANDSHAKE or not self._cryptos[tls.Epoch.HANDSHAKE].send.is_valid() ) ): self._write_ping_frame(builder, comment="probe") self._probe_pending = False if builder.packet_is_empty: break def _write_ack_frame( self, builder: QuicPacketBuilder, space: QuicPacketSpace, now: float ) -> None: # calculate ACK delay ack_delay = now - space.largest_received_time ack_delay_encoded = int(ack_delay * 1000000) >> self._local_ack_delay_exponent buf = builder.start_frame( QuicFrameType.ACK, 
capacity=ACK_FRAME_CAPACITY, handler=self._on_ack_delivery, handler_args=(space, space.largest_received_packet), ) ranges = push_ack_frame(buf, space.ack_queue, ack_delay_encoded) space.ack_at = None # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_ack_frame( ranges=space.ack_queue, delay=ack_delay ) ) # check if we need to trigger an ACK-of-ACK if ranges > 1 and builder.packet_number % 8 == 0: self._write_ping_frame(builder, comment="ACK-of-ACK trigger") def _write_connection_close_frame( self, builder: QuicPacketBuilder, epoch: tls.Epoch, error_code: int, frame_type: int | None, reason_phrase: str, ) -> None: # convert application-level close to transport-level close in early stages if frame_type is None and epoch in (tls.Epoch.INITIAL, tls.Epoch.HANDSHAKE): error_code = QuicErrorCode.APPLICATION_ERROR frame_type = QuicFrameType.PADDING reason_phrase = "" reason_bytes = reason_phrase.encode("utf8") reason_length = len(reason_bytes) if frame_type is None: buf = builder.start_frame( QuicFrameType.APPLICATION_CLOSE, capacity=APPLICATION_CLOSE_FRAME_CAPACITY + reason_length, ) buf.push_uint_var(error_code) buf.push_uint_var(reason_length) buf.push_bytes(reason_bytes) else: buf = builder.start_frame( QuicFrameType.TRANSPORT_CLOSE, capacity=TRANSPORT_CLOSE_FRAME_CAPACITY + reason_length, ) buf.push_uint_var(error_code) buf.push_uint_var(frame_type) buf.push_uint_var(reason_length) buf.push_bytes(reason_bytes) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_connection_close_frame( error_code=error_code, frame_type=frame_type, reason_phrase=reason_phrase, ) ) def _write_connection_limits( self, builder: QuicPacketBuilder, space: QuicPacketSpace ) -> None: """ Raise MAX_DATA or MAX_STREAMS if needed. 
""" for limit in ( self._local_max_data, self._local_max_streams_bidi, self._local_max_streams_uni, ): if limit.used * 2 > limit.value: limit.value *= 2 self._logger.debug("Local %s raised to %d", limit.name, limit.value) if limit.value != limit.sent: buf = builder.start_frame( limit.frame_type, capacity=CONNECTION_LIMIT_FRAME_CAPACITY, handler=self._on_connection_limit_delivery, handler_args=(limit,), ) buf.push_uint_var(limit.value) limit.sent = limit.value # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_connection_limit_frame( frame_type=limit.frame_type, maximum=limit.value, ) ) def _write_crypto_frame( self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream ) -> bool: frame_overhead = 3 + size_uint_var(stream.sender.next_offset) frame = stream.sender.get_frame(builder.remaining_flight_space - frame_overhead) if frame is not None: buf = builder.start_frame( QuicFrameType.CRYPTO, capacity=frame_overhead, handler=stream.sender.on_data_delivery, handler_args=(frame.offset, frame.offset + len(frame.data)), ) buf.push_uint_var(frame.offset) buf.push_uint16(len(frame.data) | 0x4000) buf.push_bytes(frame.data) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_crypto_frame(frame) ) return True return False def _write_datagram_frame( self, builder: QuicPacketBuilder, data: bytes, frame_type: QuicFrameType ) -> bool: """ Write a DATAGRAM frame. Returns True if the frame was processed, False otherwise. 
""" assert frame_type == QuicFrameType.DATAGRAM_WITH_LENGTH length = len(data) frame_size = 1 + size_uint_var(length) + length buf = builder.start_frame(frame_type, capacity=frame_size) buf.push_uint_var(length) buf.push_bytes(data) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_datagram_frame(length=length) ) return True def _write_handshake_done_frame(self, builder: QuicPacketBuilder) -> None: builder.start_frame( QuicFrameType.HANDSHAKE_DONE, capacity=HANDSHAKE_DONE_FRAME_CAPACITY, handler=self._on_handshake_done_delivery, ) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_handshake_done_frame() ) def _write_new_connection_id_frame( self, builder: QuicPacketBuilder, connection_id: QuicConnectionId ) -> None: retire_prior_to = 0 # FIXME buf = builder.start_frame( QuicFrameType.NEW_CONNECTION_ID, capacity=NEW_CONNECTION_ID_FRAME_CAPACITY, handler=self._on_new_connection_id_delivery, handler_args=(connection_id,), ) buf.push_uint_var(connection_id.sequence_number) buf.push_uint_var(retire_prior_to) buf.push_uint8(len(connection_id.cid)) buf.push_bytes(connection_id.cid) buf.push_bytes(connection_id.stateless_reset_token) connection_id.was_sent = True self._events.append(events.ConnectionIdIssued(connection_id=connection_id.cid)) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_new_connection_id_frame( connection_id=connection_id.cid, retire_prior_to=retire_prior_to, sequence_number=connection_id.sequence_number, stateless_reset_token=connection_id.stateless_reset_token, ) ) def _write_path_challenge_frame( self, builder: QuicPacketBuilder, challenge: bytes ) -> None: buf = builder.start_frame( QuicFrameType.PATH_CHALLENGE, capacity=PATH_CHALLENGE_FRAME_CAPACITY ) buf.push_bytes(challenge) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( 
self._quic_logger.encode_path_challenge_frame(data=challenge) ) def _write_path_response_frame( self, builder: QuicPacketBuilder, challenge: bytes ) -> None: buf = builder.start_frame( QuicFrameType.PATH_RESPONSE, capacity=PATH_RESPONSE_FRAME_CAPACITY ) buf.push_bytes(challenge) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_path_response_frame(data=challenge) ) def _write_ping_frame( self, builder: QuicPacketBuilder, uids: list[int] = [], comment="" ): builder.start_frame( QuicFrameType.PING, capacity=PING_FRAME_CAPACITY, handler=self._on_ping_delivery, handler_args=(tuple(uids),), ) self._logger.debug( "Sending PING%s in packet %d", f" ({comment})" if comment else "", builder.packet_number, ) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append(self._quic_logger.encode_ping_frame()) def _write_reset_stream_frame( self, builder: QuicPacketBuilder, stream: QuicStream, ) -> None: buf = builder.start_frame( frame_type=QuicFrameType.RESET_STREAM, capacity=RESET_STREAM_FRAME_CAPACITY, handler=stream.sender.on_reset_delivery, ) frame = stream.sender.get_reset_frame() buf.push_uint_var(frame.stream_id) buf.push_uint_var(frame.error_code) buf.push_uint_var(frame.final_size) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_reset_stream_frame( error_code=frame.error_code, final_size=frame.final_size, stream_id=frame.stream_id, ) ) def _write_retire_connection_id_frame( self, builder: QuicPacketBuilder, sequence_number: int ) -> None: buf = builder.start_frame( QuicFrameType.RETIRE_CONNECTION_ID, capacity=RETIRE_CONNECTION_ID_CAPACITY, handler=self._on_retire_connection_id_delivery, handler_args=(sequence_number,), ) buf.push_uint_var(sequence_number) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_retire_connection_id_frame(sequence_number) ) def _write_stop_sending_frame( 
self, builder: QuicPacketBuilder, stream: QuicStream, ) -> None: buf = builder.start_frame( frame_type=QuicFrameType.STOP_SENDING, capacity=STOP_SENDING_FRAME_CAPACITY, handler=stream.receiver.on_stop_sending_delivery, ) frame = stream.receiver.get_stop_frame() buf.push_uint_var(frame.stream_id) buf.push_uint_var(frame.error_code) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_stop_sending_frame( error_code=frame.error_code, stream_id=frame.stream_id ) ) def _write_stream_frame( self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream, max_offset: int, ) -> int: # the frame data size is constrained by our peer's MAX_DATA and # the space available in the current packet frame_overhead = ( 3 + size_uint_var(stream.stream_id) + ( size_uint_var(stream.sender.next_offset) if stream.sender.next_offset else 0 ) ) previous_send_highest = stream.sender.highest_offset frame = stream.sender.get_frame( builder.remaining_flight_space - frame_overhead, max_offset ) if frame is not None: frame_type = QuicFrameType.STREAM_BASE | 2 # length if frame.offset: frame_type |= 4 if frame.fin: frame_type |= 1 buf = builder.start_frame( frame_type, capacity=frame_overhead, handler=stream.sender.on_data_delivery, handler_args=(frame.offset, frame.offset + len(frame.data)), ) buf.push_uint_var(stream.stream_id) if frame.offset: buf.push_uint_var(frame.offset) buf.push_uint16(len(frame.data) | 0x4000) buf.push_bytes(frame.data) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_stream_frame( frame, stream_id=stream.stream_id ) ) return stream.sender.highest_offset - previous_send_highest else: return 0 def _write_stream_limits( self, builder: QuicPacketBuilder, space: QuicPacketSpace, stream: QuicStream ) -> None: """ Raise MAX_STREAM_DATA if needed. The only case where `stream.max_stream_data_local` is zero is for locally created unidirectional streams. 
We skip such streams to avoid spurious logging. """ if ( stream.max_stream_data_local and stream.receiver.highest_offset * 2 > stream.max_stream_data_local ): stream.max_stream_data_local *= 2 self._logger.debug( "Stream %d local max_stream_data raised to %d", stream.stream_id, stream.max_stream_data_local, ) if stream.max_stream_data_local_sent != stream.max_stream_data_local: buf = builder.start_frame( QuicFrameType.MAX_STREAM_DATA, capacity=MAX_STREAM_DATA_FRAME_CAPACITY, handler=self._on_max_stream_data_delivery, handler_args=(stream,), ) buf.push_uint_var(stream.stream_id) buf.push_uint_var(stream.max_stream_data_local) stream.max_stream_data_local_sent = stream.max_stream_data_local # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_max_stream_data_frame( maximum=stream.max_stream_data_local, stream_id=stream.stream_id ) ) def _write_streams_blocked_frame( self, builder: QuicPacketBuilder, frame_type: QuicFrameType, limit: int ) -> None: buf = builder.start_frame(frame_type, capacity=STREAMS_BLOCKED_CAPACITY) buf.push_uint_var(limit) # log frame if self._quic_logger is not None: builder.quic_logger_frames.append( self._quic_logger.encode_streams_blocked_frame( is_unidirectional=frame_type == QuicFrameType.STREAMS_BLOCKED_UNI, limit=limit, ) )