from __future__ import annotations import binascii import ipaddress import os from dataclasses import dataclass from enum import IntEnum from .._compat import DATACLASS_KWARGS from .._hazmat import AeadAes128Gcm, Buffer, RangeSet PACKET_LONG_HEADER = 0x80 PACKET_FIXED_BIT = 0x40 PACKET_SPIN_BIT = 0x20 CONNECTION_ID_MAX_SIZE = 20 PACKET_NUMBER_MAX_SIZE = 4 RETRY_AEAD_KEY_VERSION_1 = binascii.unhexlify("be0c690b9f66575a1d766b54e368c84e") RETRY_AEAD_KEY_VERSION_2 = binascii.unhexlify("8fb4b01b56ac48e260fbcbcead7ccc92") RETRY_AEAD_NONCE_VERSION_1 = binascii.unhexlify("461599d35d632bf2239825bb") RETRY_AEAD_NONCE_VERSION_2 = binascii.unhexlify("d86969bc2d7c6d9990efb04a") RETRY_INTEGRITY_TAG_SIZE = 16 STATELESS_RESET_TOKEN_SIZE = 16 class QuicErrorCode(IntEnum): NO_ERROR = 0x0 INTERNAL_ERROR = 0x1 CONNECTION_REFUSED = 0x2 FLOW_CONTROL_ERROR = 0x3 STREAM_LIMIT_ERROR = 0x4 STREAM_STATE_ERROR = 0x5 FINAL_SIZE_ERROR = 0x6 FRAME_ENCODING_ERROR = 0x7 TRANSPORT_PARAMETER_ERROR = 0x8 CONNECTION_ID_LIMIT_ERROR = 0x9 PROTOCOL_VIOLATION = 0xA INVALID_TOKEN = 0xB APPLICATION_ERROR = 0xC CRYPTO_BUFFER_EXCEEDED = 0xD KEY_UPDATE_ERROR = 0xE AEAD_LIMIT_REACHED = 0xF VERSION_NEGOTIATION_ERROR = 0x11 CRYPTO_ERROR = 0x100 class QuicPacketType(IntEnum): INITIAL = 0 ZERO_RTT = 1 HANDSHAKE = 2 RETRY = 3 VERSION_NEGOTIATION = 4 ONE_RTT = 5 # For backwards compatibility only, use `QuicPacketType` in new code. 
PACKET_TYPE_INITIAL = QuicPacketType.INITIAL

# QUIC version 1
# https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
PACKET_LONG_TYPE_ENCODE_VERSION_1 = {
    QuicPacketType.INITIAL: 0,
    QuicPacketType.ZERO_RTT: 1,
    QuicPacketType.HANDSHAKE: 2,
    QuicPacketType.RETRY: 3,
}
PACKET_LONG_TYPE_DECODE_VERSION_1 = {
    v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_1.items()
}

# QUIC version 2
# https://datatracker.ietf.org/doc/html/rfc9369#section-3.2
PACKET_LONG_TYPE_ENCODE_VERSION_2 = {
    QuicPacketType.INITIAL: 1,
    QuicPacketType.ZERO_RTT: 2,
    QuicPacketType.HANDSHAKE: 3,
    QuicPacketType.RETRY: 0,
}
PACKET_LONG_TYPE_DECODE_VERSION_2 = {
    v: i for (i, v) in PACKET_LONG_TYPE_ENCODE_VERSION_2.items()
}


class QuicProtocolVersion(IntEnum):
    NEGOTIATION = 0
    VERSION_1 = 0x00000001
    VERSION_2 = 0x6B3343CF


@dataclass(**DATACLASS_KWARGS)
class QuicHeader:
    version: int | None
    "The protocol version. Only present in long header packets."

    packet_type: QuicPacketType
    "The type of the packet."

    packet_length: int
    "The total length of the packet, in bytes."

    destination_cid: bytes
    "The destination connection ID."

    source_cid: bytes
    "The source connection ID. Empty for short header packets."

    token: bytes
    "The address verification token. Only present in `INITIAL` and `RETRY` packets."

    integrity_tag: bytes
    "The retry integrity tag. Only present in `RETRY` packets."

    supported_versions: list[int]
    "Supported protocol versions. Only present in `VERSION_NEGOTIATION` packets."


def get_retry_integrity_tag(
    packet_without_tag: bytes, original_destination_cid: bytes, version: int
) -> bytes:
    """
    Calculate the integrity tag for a RETRY packet.

    The tag is computed with AES-128-GCM over a "Retry pseudo-packet": the
    length-prefixed original destination connection ID followed by the Retry
    packet without its tag (RFC 9001, section 5.8).

    :param packet_without_tag: the serialized Retry packet, excluding the tag.
    :param original_destination_cid: the Original Destination Connection ID.
    :param version: the QUIC version. `VERSION_2` selects the v2 key and
        nonce; any other value uses the v1 key and nonce.
    :return: a tag of `RETRY_INTEGRITY_TAG_SIZE` bytes.
    """
    # build Retry pseudo packet
    buf = Buffer(capacity=1 + len(original_destination_cid) + len(packet_without_tag))
    buf.push_uint8(len(original_destination_cid))
    buf.push_bytes(original_destination_cid)
    buf.push_bytes(packet_without_tag)
    assert buf.eof()

    if version == QuicProtocolVersion.VERSION_2:
        aead_key = RETRY_AEAD_KEY_VERSION_2
        aead_nonce = RETRY_AEAD_NONCE_VERSION_2
    else:
        aead_key = RETRY_AEAD_KEY_VERSION_1
        aead_nonce = RETRY_AEAD_NONCE_VERSION_1

    # run AES-128-GCM
    # NOTE(review): the 12-byte constructor IV looks like a placeholder, since
    # the real nonce is passed to encrypt_with_nonce. Per RFC 9001 the
    # pseudo-packet is associated data over an empty plaintext; confirm the
    # (nonce, plaintext, associated_data) argument order against `_hazmat`.
    aead = AeadAes128Gcm(aead_key, b"null!12bytes")
    integrity_tag = aead.encrypt_with_nonce(aead_nonce, b"", buf.data)
    assert len(integrity_tag) == RETRY_INTEGRITY_TAG_SIZE
    return integrity_tag


def get_spin_bit(first_byte: int) -> bool:
    """
    Return whether the spin bit is set in the first byte of a packet.
    """
    return bool(first_byte & PACKET_SPIN_BIT)


def is_long_header(first_byte: int) -> bool:
    """
    Return whether the first byte of a packet denotes a long header.
    """
    return bool(first_byte & PACKET_LONG_HEADER)


def pretty_protocol_version(version: int) -> str:
    """
    Return a user-friendly representation of a protocol version.
    """
    try:
        version_name = QuicProtocolVersion(version).name
    except ValueError:
        version_name = "UNKNOWN"
    return f"0x{version:08x} ({version_name})"


def pull_quic_header(buf: Buffer, host_cid_length: int | None = None) -> QuicHeader:
    """
    Parse a QUIC packet header starting at the current position of `buf`.

    On return, `buf` is positioned right after the parsed header fields:
    - for Initial, 0-RTT and Handshake packets, after the Length field;
    - for short header packets, after the destination connection ID;
    - for Retry and Version Negotiation packets, at the end of the buffer.

    The packet itself ends `packet_length` bytes after the position `buf` had
    on entry.

    :param buf: buffer holding the packet, positioned at its first byte.
    :param host_cid_length: destination connection ID length to use for short
        header packets, which do not encode it. It must be provided to parse
        such packets.
    :raises ValueError: if a connection ID is too long, the fixed bit is
        zero, or the packet is truncated. Errors raised by `Buffer` on
        out-of-bounds reads also propagate.
    """
    packet_start = buf.tell()

    version: int | None
    integrity_tag = b""
    supported_versions: list[int] = []
    token = b""

    first_byte = buf.pull_uint8()
    if is_long_header(first_byte):
        # Long Header Packets.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2
        version = buf.pull_uint32()

        destination_cid_length = buf.pull_uint8()
        if destination_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError(
                f"Destination CID is too long ({destination_cid_length} bytes)"
            )
        destination_cid = buf.pull_bytes(destination_cid_length)

        source_cid_length = buf.pull_uint8()
        if source_cid_length > CONNECTION_ID_MAX_SIZE:
            raise ValueError(f"Source CID is too long ({source_cid_length} bytes)")
        source_cid = buf.pull_bytes(source_cid_length)

        if version == QuicProtocolVersion.NEGOTIATION:
            # Version Negotiation Packet.
            # https://datatracker.ietf.org/doc/html/rfc9000#section-17.2.1
            packet_type = QuicPacketType.VERSION_NEGOTIATION
            while not buf.eof():
                supported_versions.append(buf.pull_uint32())
            packet_end = buf.tell()
        else:
            if not (first_byte & PACKET_FIXED_BIT):
                raise ValueError("Packet fixed bit is zero")

            # The two long packet type bits are mapped differently in v1 and v2.
            long_packet_type = (first_byte & 0x30) >> 4
            if version == QuicProtocolVersion.VERSION_2:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_2[long_packet_type]
            else:
                packet_type = PACKET_LONG_TYPE_DECODE_VERSION_1[long_packet_type]

            if packet_type == QuicPacketType.INITIAL:
                token_length = buf.pull_uint_var()
                token = buf.pull_bytes(token_length)
                rest_length = buf.pull_uint_var()
            elif packet_type in (QuicPacketType.ZERO_RTT, QuicPacketType.HANDSHAKE):
                rest_length = buf.pull_uint_var()
            else:
                # Retry packets have no Length field. The token runs up to the
                # integrity tag, which occupies the last bytes of the buffer.
                token_length = buf.capacity - buf.tell() - RETRY_INTEGRITY_TAG_SIZE
                if token_length < 0:
                    # Too short to even hold the tag: reject explicitly instead
                    # of handing a negative length to the buffer.
                    raise ValueError("Retry packet is truncated")
                token = buf.pull_bytes(token_length)
                integrity_tag = buf.pull_bytes(RETRY_INTEGRITY_TAG_SIZE)
                rest_length = 0

            # Check remainder length.
            packet_end = buf.tell() + rest_length
            if packet_end > buf.capacity:
                raise ValueError("Packet payload is truncated")
    else:
        # Short Header Packets.
        # https://datatracker.ietf.org/doc/html/rfc9000#section-17.3
        if not (first_byte & PACKET_FIXED_BIT):
            raise ValueError("Packet fixed bit is zero")

        version = None
        packet_type = QuicPacketType.ONE_RTT
        destination_cid = buf.pull_bytes(host_cid_length)
        source_cid = b""
        # Short header packets have no Length field, so they extend to the
        # end of the buffer.
        packet_end = buf.capacity

    return QuicHeader(
        version=version,
        packet_type=packet_type,
        packet_length=packet_end - packet_start,
        destination_cid=destination_cid,
        source_cid=source_cid,
        token=token,
        integrity_tag=integrity_tag,
        supported_versions=supported_versions,
    )


def encode_long_header_first_byte(
    version: int, packet_type: QuicPacketType, bits: int
) -> int:
    """
    Encode the first byte of a long header packet.

    :param version: selects the v1 or v2 mapping of the long packet type bits.
    :param packet_type: one of INITIAL, ZERO_RTT, HANDSHAKE or RETRY.
    :param bits: value OR-ed into the low-order bits of the byte.
    """
    if version == QuicProtocolVersion.VERSION_2:
        long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_2
    else:
        long_type_encode = PACKET_LONG_TYPE_ENCODE_VERSION_1
    return (
        PACKET_LONG_HEADER
        | PACKET_FIXED_BIT
        | (long_type_encode[packet_type] << 4)
        | bits
    )


def encode_quic_retry(
    version: int,
    source_cid: bytes,
    destination_cid: bytes,
    original_destination_cid: bytes,
    retry_token: bytes,
    unused: int = 0,
) -> bytes:
    """
    Serialize a RETRY packet, including its trailing integrity tag.

    :param version: the QUIC version, written to the header and used to
        select the integrity tag key and nonce.
    :param original_destination_cid: input to the integrity tag computation;
        it is not written to the packet itself.
    :param unused: value OR-ed into the low-order bits of the first byte.
    """
    buf = Buffer(
        # first byte (1) + version (4) + the two CID length bytes (1 + 1)
        capacity=7
        + len(destination_cid)
        + len(source_cid)
        + len(retry_token)
        + RETRY_INTEGRITY_TAG_SIZE
    )
    buf.push_uint8(encode_long_header_first_byte(version, QuicPacketType.RETRY, unused))
    buf.push_uint32(version)
    buf.push_uint8(len(destination_cid))
    buf.push_bytes(destination_cid)
    buf.push_uint8(len(source_cid))
    buf.push_bytes(source_cid)
    buf.push_bytes(retry_token)
    # The tag covers everything written so far.
    buf.push_bytes(
        get_retry_integrity_tag(buf.data, original_destination_cid, version=version)
    )
    assert buf.eof()
    return buf.data


def encode_quic_version_negotiation(
    source_cid: bytes, destination_cid: bytes, supported_versions: list[int]
) -> bytes:
    """
    Serialize a Version Negotiation packet listing `supported_versions`.
    """
    buf = Buffer(
        capacity=7
        + len(destination_cid)
        + len(source_cid)
        + 4 * len(supported_versions)
    )
    # Only the long header bit is meaningful; the remaining bits are random.
    buf.push_uint8(os.urandom(1)[0] | PACKET_LONG_HEADER)
    buf.push_uint32(QuicProtocolVersion.NEGOTIATION)
    buf.push_uint8(len(destination_cid))
    buf.push_bytes(destination_cid)
    buf.push_uint8(len(source_cid))
    buf.push_bytes(source_cid)
    for version in supported_versions:
        buf.push_uint32(version)
    return buf.data


# TLS EXTENSION


@dataclass(**DATACLASS_KWARGS)
class QuicPreferredAddress:
    # (host, port), or None when the all-zero address was sent
    ipv4_address: tuple[str, int] | None
    # (host, port), or None when the all-zero address was sent
    ipv6_address: tuple[str, int] | None
    connection_id: bytes
    stateless_reset_token: bytes


@dataclass(**DATACLASS_KWARGS)
class QuicVersionInformation:
    chosen_version: int
    available_versions: list[int]


@dataclass()
class QuicTransportParameters:
    # Each field corresponds to an entry in PARAMS. None (or False for
    # disable_active_migration) means the parameter is not sent.
    original_destination_connection_id: bytes | None = None
    max_idle_timeout: int | None = None
    stateless_reset_token: bytes | None = None
    max_udp_payload_size: int | None = None
    initial_max_data: int | None = None
    initial_max_stream_data_bidi_local: int | None = None
    initial_max_stream_data_bidi_remote: int | None = None
    initial_max_stream_data_uni: int | None = None
    initial_max_streams_bidi: int | None = None
    initial_max_streams_uni: int | None = None
    ack_delay_exponent: int | None = None
    max_ack_delay: int | None = None
    disable_active_migration: bool | None = False
    preferred_address: QuicPreferredAddress | None = None
    active_connection_id_limit: int | None = None
    initial_source_connection_id: bytes | None = None
    retry_source_connection_id: bytes | None = None
    version_information: QuicVersionInformation | None = None
    max_datagram_frame_size: int | None = None
    quantum_readiness: bytes | None = None


# Transport parameter ID -> (attribute name on QuicTransportParameters, value
# type). The type selects the wire encoding: int is a variable-length integer,
# bytes is raw, bool is a zero-length flag.
PARAMS = {
    0x00: ("original_destination_connection_id", bytes),
    0x01: ("max_idle_timeout", int),
    0x02: ("stateless_reset_token", bytes),
    0x03: ("max_udp_payload_size", int),
    0x04: ("initial_max_data", int),
    0x05: ("initial_max_stream_data_bidi_local", int),
    0x06: ("initial_max_stream_data_bidi_remote", int),
    0x07: ("initial_max_stream_data_uni", int),
    0x08: ("initial_max_streams_bidi", int),
    0x09: ("initial_max_streams_uni", int),
    0x0A: ("ack_delay_exponent", int),
    0x0B: ("max_ack_delay", int),
    0x0C: ("disable_active_migration", bool),
    0x0D: ("preferred_address", QuicPreferredAddress),
    0x0E: ("active_connection_id_limit", int),
    0x0F: ("initial_source_connection_id", bytes),
    0x10: ("retry_source_connection_id", bytes),
    # https://datatracker.ietf.org/doc/html/rfc9368#section-3
    0x11: ("version_information", QuicVersionInformation),
    # extensions
    0x0020: ("max_datagram_frame_size", int),
    0x0C37: ("quantum_readiness", bytes),
}


def pull_quic_preferred_address(buf: Buffer) -> QuicPreferredAddress:
    """
    Parse the value of a preferred_address transport parameter.

    An all-zero IPv4 or IPv6 address is reported as None.

    :raises ValueError: if the connection ID is longer than
        `CONNECTION_ID_MAX_SIZE`.
    """
    ipv4_address = None
    ipv4_host = buf.pull_bytes(4)
    ipv4_port = buf.pull_uint16()
    if ipv4_host != bytes(4):
        ipv4_address = (str(ipaddress.IPv4Address(ipv4_host)), ipv4_port)

    ipv6_address = None
    ipv6_host = buf.pull_bytes(16)
    ipv6_port = buf.pull_uint16()
    if ipv6_host != bytes(16):
        ipv6_address = (str(ipaddress.IPv6Address(ipv6_host)), ipv6_port)

    connection_id_length = buf.pull_uint8()
    # Same bound as the connection IDs in packet headers.
    if connection_id_length > CONNECTION_ID_MAX_SIZE:
        raise ValueError(
            f"Preferred address CID is too long ({connection_id_length} bytes)"
        )
    connection_id = buf.pull_bytes(connection_id_length)
    stateless_reset_token = buf.pull_bytes(STATELESS_RESET_TOKEN_SIZE)

    return QuicPreferredAddress(
        ipv4_address=ipv4_address,
        ipv6_address=ipv6_address,
        connection_id=connection_id,
        stateless_reset_token=stateless_reset_token,
    )


def push_quic_preferred_address(
    buf: Buffer, preferred_address: QuicPreferredAddress
) -> None:
    """
    Serialize the value of a preferred_address transport parameter.

    A missing IPv4 or IPv6 address is written as all zeros (address and port).
    """
    if preferred_address.ipv4_address is not None:
        buf.push_bytes(ipaddress.IPv4Address(preferred_address.ipv4_address[0]).packed)
        buf.push_uint16(preferred_address.ipv4_address[1])
    else:
        buf.push_bytes(bytes(6))  # 4-byte address + 2-byte port

    if preferred_address.ipv6_address is not None:
        buf.push_bytes(ipaddress.IPv6Address(preferred_address.ipv6_address[0]).packed)
        buf.push_uint16(preferred_address.ipv6_address[1])
    else:
        buf.push_bytes(bytes(18))  # 16-byte address + 2-byte port

    buf.push_uint8(len(preferred_address.connection_id))
    buf.push_bytes(preferred_address.connection_id)
    buf.push_bytes(preferred_address.stateless_reset_token)


def pull_quic_version_information(buf: Buffer, length: int) -> QuicVersionInformation:
    """
    Parse the value of a version_information transport parameter.

    :param length: size of the parameter value, in bytes. It holds one chosen
        version followed by the available versions, 4 bytes each.
    :raises ValueError: if any version is zero.
    """
    chosen_version = buf.pull_uint32()
    available_versions = [buf.pull_uint32() for _ in range(length // 4 - 1)]

    # If an endpoint receives a Chosen Version equal to zero, or any Available Version
    # equal to zero, it MUST treat it as a parsing failure.
    #
    # https://datatracker.ietf.org/doc/html/rfc9368#section-4
    if chosen_version == 0 or 0 in available_versions:
        raise ValueError("Version Information must not contain version 0")

    return QuicVersionInformation(
        chosen_version=chosen_version,
        available_versions=available_versions,
    )


def push_quic_version_information(
    buf: Buffer, version_information: QuicVersionInformation
) -> None:
    """
    Serialize the value of a version_information transport parameter.
    """
    buf.push_uint32(version_information.chosen_version)
    for version in version_information.available_versions:
        buf.push_uint32(version)


def pull_quic_transport_parameters(buf: Buffer) -> QuicTransportParameters:
    """
    Parse a sequence of transport parameters until the end of `buf`.

    Parameters not listed in PARAMS are skipped. If a parameter appears more
    than once, the last occurrence wins.

    :raises ValueError: if a parameter value does not consume exactly its
        declared length, or a nested parser rejects the value.
    """
    params = QuicTransportParameters()
    while not buf.eof():
        param_id = buf.pull_uint_var()
        param_len = buf.pull_uint_var()
        param_start = buf.tell()
        if param_id in PARAMS:
            # parse known parameter
            param_name, param_type = PARAMS[param_id]
            if param_type is int:
                setattr(params, param_name, buf.pull_uint_var())
            elif param_type is bytes:
                setattr(params, param_name, buf.pull_bytes(param_len))
            elif param_type is QuicPreferredAddress:
                setattr(params, param_name, pull_quic_preferred_address(buf))
            elif param_type is QuicVersionInformation:
                setattr(
                    params,
                    param_name,
                    pull_quic_version_information(buf, param_len),
                )
            else:
                # Zero-length flag parameter: its presence means True.
                setattr(params, param_name, True)
        else:
            # skip unknown parameter
            buf.pull_bytes(param_len)

        if buf.tell() != param_start + param_len:
            raise ValueError("Transport parameter length does not match")

    return params


def push_quic_transport_parameters(
    buf: Buffer, params: QuicTransportParameters
) -> None:
    """
    Serialize every transport parameter in `params` that is set.

    Attributes that are None or False are omitted.
    """
    for param_id, (param_name, param_type) in PARAMS.items():
        param_value = getattr(params, param_name)
        if param_value is not None and param_value is not False:
            param_buf = Buffer(capacity=65536)
            if param_type is int:
                param_buf.push_uint_var(param_value)
            elif param_type is bytes:
                param_buf.push_bytes(param_value)
            elif param_type is QuicPreferredAddress:
                push_quic_preferred_address(param_buf, param_value)
            elif param_type is QuicVersionInformation:
                push_quic_version_information(param_buf, param_value)
            # A bool flag leaves param_buf empty, producing a zero-length value.
            buf.push_uint_var(param_id)
            buf.push_uint_var(param_buf.tell())
            buf.push_bytes(param_buf.data)


# FRAMES


class QuicFrameType(IntEnum):
    """
    QUIC frame types (RFC 9000, section 19; DATAGRAM frames from RFC 9221).

    STREAM frames occupy types 0x08-0x0f: the low three bits of STREAM_BASE
    carry the OFF, LEN and FIN flags.
    """

    PADDING = 0x00
    PING = 0x01
    ACK = 0x02
    ACK_ECN = 0x03
    RESET_STREAM = 0x04
    STOP_SENDING = 0x05
    CRYPTO = 0x06
    NEW_TOKEN = 0x07
    STREAM_BASE = 0x08
    MAX_DATA = 0x10
    MAX_STREAM_DATA = 0x11
    MAX_STREAMS_BIDI = 0x12
    MAX_STREAMS_UNI = 0x13
    DATA_BLOCKED = 0x14
    STREAM_DATA_BLOCKED = 0x15
    STREAMS_BLOCKED_BIDI = 0x16
    STREAMS_BLOCKED_UNI = 0x17
    NEW_CONNECTION_ID = 0x18
    RETIRE_CONNECTION_ID = 0x19
    PATH_CHALLENGE = 0x1A
    PATH_RESPONSE = 0x1B
    TRANSPORT_CLOSE = 0x1C
    APPLICATION_CLOSE = 0x1D
    HANDSHAKE_DONE = 0x1E
    DATAGRAM = 0x30
    DATAGRAM_WITH_LENGTH = 0x31


NON_ACK_ELICITING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.PADDING,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)
# Unlike the set above, PADDING is absent here: padding-only packets still
# count as in flight.
NON_IN_FLIGHT_FRAME_TYPES = frozenset(
    [
        QuicFrameType.ACK,
        QuicFrameType.ACK_ECN,
        QuicFrameType.TRANSPORT_CLOSE,
        QuicFrameType.APPLICATION_CLOSE,
    ]
)

PROBING_FRAME_TYPES = frozenset(
    [
        QuicFrameType.PATH_CHALLENGE,
        QuicFrameType.PATH_RESPONSE,
        QuicFrameType.PADDING,
        QuicFrameType.NEW_CONNECTION_ID,
    ]
)


@dataclass(**DATACLASS_KWARGS)
class QuicResetStreamFrame:
    error_code: int
    final_size: int
    stream_id: int


@dataclass(**DATACLASS_KWARGS)
class QuicStopSendingFrame:
    error_code: int
    stream_id: int


@dataclass(**DATACLASS_KWARGS)
class QuicStreamFrame:
    data: bytes = b""
    fin: bool = False
    offset: int = 0


def pull_ack_frame(buf: Buffer) -> tuple[RangeSet, int]:
    """
    Parse the body of an ACK frame (RFC 9000, section 19.3).

    Reads Largest Acknowledged, ACK Delay, ACK Range Count and the ACK ranges.
    The ECN counts of an ACK_ECN frame are not consumed here.

    :return: the acknowledged packet numbers as a `RangeSet` of half-open
        ranges, and the ACK Delay field as encoded (not scaled by the
        ack_delay_exponent).
    """
    # NOTE(review): RFC 9000 section 19.3.1 requires FRAME_ENCODING_ERROR if a
    # computed packet number is negative. No such check is made here, so
    # `end` can go negative on malformed input. Confirm how RangeSet and the
    # callers handle this.
    rangeset = RangeSet()
    end = buf.pull_uint_var()  # largest acknowledged
    delay = buf.pull_uint_var()
    ack_range_count = buf.pull_uint_var()
    ack_count = buf.pull_uint_var()  # first ack range
    rangeset.add(end - ack_count, end + 1)
    end -= ack_count
    for _ in range(ack_range_count):
        # Gap is encoded as (unacknowledged packets - 1), and the next range
        # starts one below the gap, hence the + 2.
        end -= buf.pull_uint_var() + 2
        ack_count = buf.pull_uint_var()
        rangeset.add(end - ack_count, end + 1)
        end -= ack_count
    return rangeset, delay


def push_ack_frame(buf: Buffer, rangeset: RangeSet, delay: int) -> int:
    """
    Serialize the body of an ACK frame from a non-empty `rangeset`.

    Ranges are written from the highest packet numbers down, as required by
    the ACK frame encoding.

    :param rangeset: acknowledged packet numbers as half-open ranges, in
        ascending order. It must not be empty.
    :param delay: the ACK Delay value, written as-is.
    :return: the number of ranges written.
    """
    ranges = len(rangeset)
    index = ranges - 1
    r = rangeset[index]
    buf.push_uint_var(r[1] - 1)  # largest acknowledged
    buf.push_uint_var(delay)
    buf.push_uint_var(index)  # ACK Range Count excludes the first range
    buf.push_uint_var(r[1] - 1 - r[0])  # first ACK range
    start = r[0]
    while index > 0:
        index -= 1
        r = rangeset[index]
        buf.push_uint_var(start - r[1] - 1)  # gap
        buf.push_uint_var(r[1] - r[0] - 1)  # ACK range length
        start = r[0]
    return ranges